diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index d26d8ecee2..21ac8fbd45 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -61,6 +61,10 @@ jobs:
             python-version: 3.8
             requires: 'latest'
             topic: ['graph']
+          - os: ubuntu-20.04
+            python-version: 3.8
+            requires: 'latest'
+            topic: ['audio']

     # Timeout: https://stackoverflow.com/a/59076067/4521646
     timeout-minutes: 35
@@ -128,6 +132,13 @@ jobs:
         run: |
           pip install '.[all]' --pre --upgrade

+    - name: Install audio test dependencies
+      if: matrix.topic[0] == 'audio'
+      run: |
+        sudo apt-get install libsndfile1
+        pip install matplotlib
+        pip install '.[image]' --pre --upgrade
+
     - name: Cache datasets
       uses: actions/cache@v2
       with:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 54851b160e..cb7c1cb3b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for `field` parameter for loadng JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585))

+- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))
+
 ### Changed

 - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
index d3312220d7..d050db39c5 100644
--- a/docs/source/_templates/layout.html
+++ b/docs/source/_templates/layout.html
@@ -4,7 +4,7 @@
 {% block footer %}
 {{ super() }}
 {% endblock %}
diff --git a/docs/source/index.rst b/docs/source/index.rst
index cf3917f11d..2ac114009c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -40,6 +40,12 @@ Lightning Flash
    reference/style_transfer
    reference/video_classification

+.. toctree::
+   :maxdepth: 1
+   :caption: Audio
+
+   reference/audio_classification
+
 .. toctree::
    :maxdepth: 1
    :caption: Tabular
diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst
new file mode 100644
index 0000000000..eb122e6995
--- /dev/null
+++ b/docs/source/reference/audio_classification.rst
@@ -0,0 +1,73 @@
+
+.. _audio_classification:
+
+####################
+Audio Classification
+####################
+
+********
+The Task
+********
+
+The task of identifying what is in an audio file is called audio classification.
+Typically, audio classification is used to identify audio files containing sounds or words.
+The task predicts which ‘class’ the sound or words most likely belong to, with a degree of certainty.
+A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’, etc.
+
+------
+
+*******
+Example
+*******
+
+Let's look at the task of predicting whether an audio file contains sounds of an air conditioner, car horn, children playing, dog bark, drilling, engine idling, gun shot, jackhammer, siren, or street music, using the UrbanSound8k spectrogram images dataset.
+The dataset contains ``train``, ``val`` and ``test`` folders, and each of these contains an **air_conditioner** folder with spectrograms generated from air-conditioner sounds, a **siren** folder with spectrograms generated from siren sounds, and so on for the other classes.
+
+.. code-block::
+
+    urban8k_images
+    ├── train
+    │   ├── air_conditioner
+    │   ├── car_horn
+    │   ├── children_playing
+    │   ├── dog_bark
+    │   ├── drilling
+    │   ├── engine_idling
+    │   ├── gun_shot
+    │   ├── jackhammer
+    │   ├── siren
+    │   └── street_music
+    ├── test
+    │   ├── air_conditioner
+    │   ├── car_horn
+    │   ├── children_playing
+    │   ├── dog_bark
+    │   ├── drilling
+    │   ├── engine_idling
+    │   ├── gun_shot
+    │   ├── jackhammer
+    │   ├── siren
+    │   └── street_music
+    └── val
+        ├── air_conditioner
+        ├── car_horn
+        ├── children_playing
+        ├── dog_bark
+        ├── drilling
+        ├── engine_idling
+        ├── gun_shot
+        ├── jackhammer
+        ├── siren
+        └── street_music
+
+    ...
+
+Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`.
+We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram data.
+We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference.
+Finally, we save the model.
+Here's the full example:
+
+.. literalinclude:: ../../../flash_examples/audio_classification.py
+    :language: python
+    :lines: 14-
diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py
new file mode 100644
index 0000000000..40eeaae124
--- /dev/null
+++ b/flash/audio/__init__.py
@@ -0,0 +1 @@
+from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess  # noqa: F401
diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py
new file mode 100644
index 0000000000..476a303d49
--- /dev/null
+++ b/flash/audio/classification/__init__.py
@@ -0,0 +1 @@
+from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess  # noqa: F401
diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py
new file mode 100644
index 0000000000..68678b2a1b
--- /dev/null
+++ b/flash/audio/classification/data.py
@@ -0,0 +1,87 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
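+
+# Spectrogram classification reuses the image pipeline: the spectrogram images
+# are loaded with the image data sources and deserializer below, so only the
+# default transforms are audio-specific.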
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from flash.audio.classification.transforms import default_transforms, train_default_transforms
+from flash.core.data.callback import BaseDataFetcher
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import DefaultDataSources
+from flash.core.data.process import Deserializer, Preprocess
+from flash.core.utilities.imports import requires_extras
+from flash.image.classification.data import MatplotlibVisualization
+from flash.image.data import ImageDeserializer, ImagePathsDataSource
+
+
+class AudioClassificationPreprocess(Preprocess):
+
+    @requires_extras(["audio", "image"])
+    def __init__(
+        self,
+        train_transform: Optional[Dict[str, Callable]],
+        val_transform: Optional[Dict[str, Callable]],
+        test_transform: Optional[Dict[str, Callable]],
+        predict_transform: Optional[Dict[str, Callable]],
+        spectrogram_size: Tuple[int, int] = (196, 196),
+        time_mask_param: int = 80,
+        freq_mask_param: int = 80,
+        deserializer: Optional['Deserializer'] = None,
+    ):
+        self.spectrogram_size = spectrogram_size
+        self.time_mask_param = time_mask_param
+        self.freq_mask_param = freq_mask_param
+
+        super().__init__(
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_sources={
+                DefaultDataSources.FILES: ImagePathsDataSource(),
+                DefaultDataSources.FOLDERS: ImagePathsDataSource()
+            },
+            deserializer=deserializer or ImageDeserializer(),
+            default_data_source=DefaultDataSources.FILES,
+        )
+
+    def get_state_dict(self) -> Dict[str, Any]:
+        return {
+            **self.transforms,
+            "spectrogram_size": self.spectrogram_size,
+            "time_mask_param": self.time_mask_param,
+            "freq_mask_param": self.freq_mask_param,
+        }
+
+    @classmethod
+    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
+        return cls(**state_dict)
+
+    def default_transforms(self) -> Optional[Dict[str, Callable]]:
+        return default_transforms(self.spectrogram_size)
+
+    def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
+        return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)
+
+
+class AudioClassificationData(DataModule):
+    """Data module for audio classification."""
+
+    preprocess_cls = AudioClassificationPreprocess
+
+    def set_block_viz_window(self, value: bool) -> None:
+        """Setter method to switch on/off the matplotlib pop-up windows."""
+        self.data_fetcher.block_viz_window = value
+
+    @staticmethod
+    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
+        return MatplotlibVisualization(*args, **kwargs)
diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py
new file mode 100644
index 0000000000..02a9ed2cbc
--- /dev/null
+++ b/flash/audio/classification/transforms.py
@@ -0,0 +1,54 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
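+
+# The transforms below operate on spectrogram images keyed by
+# DefaultDataKeys.INPUT; training additionally applies torchaudio's
+# TimeMasking and FrequencyMasking (SpecAugment-style augmentations)
+# once the inputs have been converted to tensors.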
+from typing import Callable, Dict, Tuple
+
+import torch
+from torch import nn
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
+from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    import torchvision
+    from torchvision import transforms as T
+
+if _TORCHAUDIO_AVAILABLE:
+    from torchaudio import transforms as TAudio
+
+
+def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]:
+    """The default transforms for audio spectrogram classification: resize the spectrogram,
+    convert the spectrogram and target to a tensor, and collate the batch."""
+    return {
+        "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
+        "to_tensor_transform": nn.Sequential(
+            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
+        ),
+        "collate": kornia_collate,
+    }
+
+
+def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
+                             freq_mask_param: int) -> Dict[str, Callable]:
+    """During training, we apply the default transforms with additional ``TimeMasking`` and ``FrequencyMasking``."""
+    transforms = {
+        "post_tensor_transform": nn.Sequential(
+            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
+            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
+        )
+    }
+
+    return merge_transforms(default_transforms(spectrogram_size), transforms)
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index 9922f49eba..80c6b6188c 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -16,6 +16,7 @@
 import operator
 import types
 from importlib.util import find_spec
+from typing import Callable, List, Union

 from pkg_resources import DistributionNotFound

@@ -89,6 +90,7 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
 _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
 _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
+_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")

 if Version:
     _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
@@ -108,6 +110,7 @@ def _compare_version(package: str, op, version) -> bool:
 _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE
 _AUDIO_AVAILABLE = all([
     _ASTEROID_AVAILABLE,
+    _TORCHAUDIO_AVAILABLE,
 ])
 _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE

@@ -123,15 +126,22 @@ def _compare_version(package: str, op, version) -> bool:
 }


-def _requires(module_path: str, module_available: bool):
+def _requires(
+    module_paths: Union[str, List],
+    module_available: Callable[[str], bool],
+    formatter: Callable[[List[str]], str],
+):
+
+    if not isinstance(module_paths, list):
+        module_paths = [module_paths]

     def decorator(func):
-        if not module_available:
+        if not all(module_available(module_path) for module_path in module_paths):

             @functools.wraps(func)
             def wrapper(*args, **kwargs):
                 raise ModuleNotFoundError(
-                    f"Required dependencies not available. Please run: pip install '{module_path}'"
+                    f"Required dependencies not available. Please run: pip install {formatter(module_paths)}"
                 )

             return wrapper
@@ -141,12 +151,14 @@ def wrapper(*args, **kwargs):
     return decorator


-def requires(module_path: str):
-    return _requires(module_path, _module_available(module_path))
+def requires(module_paths: Union[str, List]):
+    return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths))


-def requires_extras(extras: str):
-    return _requires(f"lightning-flash[{extras}]", _EXTRAS_AVAILABLE[extras])
+def requires_extras(extras: Union[str, List]):
+    return _requires(
+        extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'"
+    )


 def lazy_import(module_name, callback=None):
diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py
new file mode 100644
index 0000000000..b8f0f8a312
--- /dev/null
+++ b/flash_examples/audio_classification.py
@@ -0,0 +1,45 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.audio import AudioClassificationData
+from flash.core.data.utils import download_data
+from flash.core.finetuning import FreezeUnfreeze
+from flash.image import ImageClassifier
+
+# 1. Create the DataModule
+download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")
+
+datamodule = AudioClassificationData.from_folders(
+    train_folder="data/urban8k_images/train",
+    val_folder="data/urban8k_images/val",
+    spectrogram_size=(64, 64),
+)
+
+# 2. Build the model.
+model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=3)
+trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
+
+# 4. Predict what's in a few spectrogram images! air_conditioner, children_playing, siren, etc.
+predictions = model.predict([
+    "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
+    "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
+    "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
+])
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("audio_classification_model.pt")
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
index 03c90d99ec..e608a13b78 100644
--- a/requirements/datatype_audio.txt
+++ b/requirements/datatype_audio.txt
@@ -1 +1,2 @@
 asteroid>=0.5.1
+torchaudio
diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/__init__.py b/tests/audio/classification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py
new file mode 100644
index 0000000000..a1c0ba0677
--- /dev/null
+++ b/tests/audio/classification/test_data.py
@@ -0,0 +1,340 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+from typing import Any, List, Tuple
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+from flash.audio import AudioClassificationData
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.transforms import ApplyToKeys
+from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
+from tests.helpers.utils import _AUDIO_TESTING
+
+if _TORCHVISION_AVAILABLE:
+    import torchvision
+
+if _PIL_AVAILABLE:
+    from PIL import Image
+
+
+def _rand_image(size: Tuple[int, int] = None):
+    if size is None:
+        _size = np.random.choice([196, 244])
+        size = (_size, _size)
+    return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8"))
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_smoke(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "b").mkdir()
+    _rand_image().save(tmpdir / "a_1.png")
+    _rand_image().save(tmpdir / "b_1.png")
+
+    train_images = [
+        str(tmpdir / "a_1.png"),
+        str(tmpdir / "b_1.png"),
+    ]
+
+    spectrograms_data = AudioClassificationData.from_files(
+        train_files=train_images,
+        train_targets=[1, 2],
+        batch_size=2,
+        num_workers=0,
+    )
+    assert spectrograms_data.train_dataloader() is not None
+    assert spectrograms_data.val_dataloader() is None
+    assert spectrograms_data.test_dataloader() is None
+
+    data = next(iter(spectrograms_data.train_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert sorted(list(labels.numpy())) == [1, 2]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_list_image_paths(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "e").mkdir()
+    _rand_image().save(tmpdir / "e_1.png")
+
+    train_images = [
+        str(tmpdir / "e_1.png"),
+        str(tmpdir / "e_1.png"),
+        str(tmpdir / "e_1.png"),
+    ]
+
+    spectrograms_data = AudioClassificationData.from_files(
+        train_files=train_images,
+        train_targets=[0, 3, 6],
+        val_files=train_images,
+        val_targets=[1, 4, 7],
+        test_files=train_images,
+        test_targets=[2, 5, 8],
+        batch_size=2,
+        num_workers=0,
+    )
+
+    # check training data
+    data = next(iter(spectrograms_data.train_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert labels.numpy()[0] in [0, 3, 6]  # data comes shuffled here
+    assert labels.numpy()[1] in [0, 3, 6]  # data comes shuffled here
+
+    # check validation data
+    data = next(iter(spectrograms_data.val_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert list(labels.numpy()) == [1, 4]
+
+    # check test data
+    data = next(iter(spectrograms_data.test_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
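+    # (196, 196) is the default spectrogram_size set by AudioClassificationPreprocess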
+    assert labels.shape == (2, )
+    assert list(labels.numpy()) == [2, 5]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
+def test_from_filepaths_visualise(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "e").mkdir()
+    _rand_image().save(tmpdir / "e_1.png")
+
+    train_images = [
+        str(tmpdir / "e_1.png"),
+        str(tmpdir / "e_1.png"),
+        str(tmpdir / "e_1.png"),
+    ]
+
+    dm = AudioClassificationData.from_files(
+        train_files=train_images,
+        train_targets=[0, 3, 6],
+        val_files=train_images,
+        val_targets=[1, 4, 7],
+        test_files=train_images,
+        test_targets=[2, 5, 8],
+        batch_size=2,
+        num_workers=0,
+    )
+
+    # disable visualisation for testing
+    assert dm.data_fetcher.block_viz_window is True
+    dm.set_block_viz_window(False)
+    assert dm.data_fetcher.block_viz_window is False
+
+    # call show functions
+    # dm.show_train_batch()
+    dm.show_train_batch("pre_tensor_transform")
+    dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
+def test_from_filepaths_visualise_multilabel(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "b").mkdir()
+
+    image_a = str(tmpdir / "a" / "a_1.png")
+    image_b = str(tmpdir / "b" / "b_1.png")
+
+    _rand_image().save(image_a)
+    _rand_image().save(image_b)
+
+    dm = AudioClassificationData.from_files(
+        train_files=[image_a, image_b],
+        train_targets=[[0, 1, 0], [0, 1, 1]],
+        val_files=[image_b, image_a],
+        val_targets=[[1, 1, 0], [0, 0, 1]],
+        test_files=[image_b, image_b],
+        test_targets=[[0, 0, 1], [1, 1, 0]],
+        batch_size=2,
+        spectrogram_size=(64, 64),
+    )
+    # disable visualisation for testing
+    assert dm.data_fetcher.block_viz_window is True
+    dm.set_block_viz_window(False)
+    assert dm.data_fetcher.block_viz_window is False
+
+    # call show functions
+    dm.show_train_batch()
+    dm.show_train_batch("pre_tensor_transform")
+    dm.show_train_batch("to_tensor_transform")
+    dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+    dm.show_val_batch("per_batch_transform")
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_splits(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    B, _, H, W = 2, 3, 224, 224
+    img_size: Tuple[int, int] = (H, W)
+
+    (tmpdir / "splits").mkdir()
+    _rand_image(img_size).save(tmpdir / "s.png")
+
+    num_samples: int = 10
+    val_split: float = .3
+
+    train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)]
+
+    train_labels: List[int] = list(range(num_samples))
+
+    assert len(train_filepaths) == len(train_labels)
+
+    _to_tensor = {
+        "to_tensor_transform": nn.Sequential(
+            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor)
+        ),
+    }
+
+    def run(transform: Any = None):
+        dm = AudioClassificationData.from_files(
+            train_files=train_filepaths,
+            train_targets=train_labels,
+            train_transform=transform,
+            val_transform=transform,
+            batch_size=B,
+            num_workers=0,
+            val_split=val_split,
+            spectrogram_size=img_size,
+        )
+        data = next(iter(dm.train_dataloader()))
+        imgs, labels = data['input'], data['target']
+        assert imgs.shape == (B, 3, H, W)
+        assert labels.shape == (B, )
+
+    run(_to_tensor)
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_folders_only_train(tmpdir):
+    train_dir = Path(tmpdir / "train")
+    train_dir.mkdir()
+
+    (train_dir / "a").mkdir()
+    _rand_image().save(train_dir / "a" / "1.png")
+    _rand_image().save(train_dir / "a" / "2.png")
+
+    (train_dir / "b").mkdir()
+    _rand_image().save(train_dir / "b" / "1.png")
+    _rand_image().save(train_dir / "b" / "2.png")
+
+    spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1)
+
+    data = next(iter(spectrograms_data.train_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (1, 3, 196, 196)
+    assert labels.shape == (1, )
+
+    assert spectrograms_data.val_dataloader() is None
+    assert spectrograms_data.test_dataloader() is None
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_folders_train_val(tmpdir):
+
+    train_dir = Path(tmpdir / "train")
+    train_dir.mkdir()
+
+    (train_dir / "a").mkdir()
+    _rand_image().save(train_dir / "a" / "1.png")
+    _rand_image().save(train_dir / "a" / "2.png")
+
+    (train_dir / "b").mkdir()
+    _rand_image().save(train_dir / "b" / "1.png")
+    _rand_image().save(train_dir / "b" / "2.png")
+    spectrograms_data = AudioClassificationData.from_folders(
+        train_dir,
+        val_folder=train_dir,
+        test_folder=train_dir,
+        batch_size=2,
+        num_workers=0,
+    )
+
+    data = next(iter(spectrograms_data.train_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+
+    data = next(iter(spectrograms_data.val_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert list(labels.numpy()) == [0, 0]
+
+    data = next(iter(spectrograms_data.test_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert list(labels.numpy()) == [0, 0]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_multilabel(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    _rand_image().save(tmpdir / "a1.png")
+    _rand_image().save(tmpdir / "a2.png")
+
+    train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")]
+    train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]]
+    valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]]
+    test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]]
+
+    dm = AudioClassificationData.from_files(
+        train_files=train_images,
+        train_targets=train_labels,
+        val_files=train_images,
+        val_targets=valid_labels,
+        test_files=train_images,
+        test_targets=test_labels,
+        batch_size=2,
+        num_workers=0,
+    )
+
+    data = next(iter(dm.train_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, 4)
+
+    data = next(iter(dm.val_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, 4)
+    torch.testing.assert_allclose(labels, torch.tensor(valid_labels))
+
+    data = next(iter(dm.test_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, 4)
+    torch.testing.assert_allclose(labels, torch.tensor(test_labels))
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index ec6c4bb834..56b729e36e 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -21,6 +21,7 @@ from flash.core.utilities.imports import _SKLEARN_AVAILABLE
 from tests.examples.utils import run_test
 from tests.helpers.utils import (
+    _AUDIO_TESTING,
     _GRAPH_TESTING,
     _IMAGE_TESTING,
     _POINTCLOUD_TESTING,
@@ -37,6 +38,10 @@
     pytest.param(
         "custom_task.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
     ),
+    pytest.param(
+        "audio_classification.py",
+        marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed")
+    ),
     pytest.param(
         "image_classification.py",
         marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 5bb699b664..bd57cf570d 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -14,6 +14,7 @@ import os

 from flash.core.utilities.imports import (
+    _AUDIO_AVAILABLE,
     _GRAPH_AVAILABLE,
     _IMAGE_AVAILABLE,
     _POINTCLOUD_AVAILABLE,
@@ -30,6 +31,7 @@
 _SERVE_TESTING = _SERVE_AVAILABLE
 _POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE
 _GRAPH_TESTING = _GRAPH_AVAILABLE
+_AUDIO_TESTING = _AUDIO_AVAILABLE

 if "FLASH_TEST_TOPIC" in os.environ:
     topic = os.environ["FLASH_TEST_TOPIC"]
@@ -40,3 +42,4 @@
     _SERVE_TESTING = topic == "serve"
     _POINTCLOUD_TESTING = topic == "pointcloud"
     _GRAPH_TESTING = topic == "graph"
+    _AUDIO_TESTING = topic == "audio"
diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py
index 6a80b5774a..87cb183504 100644
--- a/tests/image/classification/test_data.py
+++ b/tests/image/classification/test_data.py
@@ -168,7 +168,7 @@ def test_from_filepaths_visualise(tmpdir):
     dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])


-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
 @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
 def test_from_filepaths_visualise_multilabel(tmpdir):
     tmpdir = Path(tmpdir)
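
Reviewer note, not part of the patch: a short sketch of how the reworked decorators in `flash/core/utilities/imports.py` behave now that `requires_extras` accepts a list of extras (all listed extras must be installed for the decorated callable to run). The function below is a hypothetical stand-in; the decorator, its availability check, and the error text come from the changes above.

    from flash.core.utilities.imports import requires_extras

    @requires_extras(["audio", "image"])  # same gating used on AudioClassificationPreprocess.__init__
    def build_spectrogram_pipeline():  # hypothetical function, for illustration only
        ...

    # If any listed extra is missing, *calling* the function raises:
    #   ModuleNotFoundError: Required dependencies not available.
    #   Please run: pip install 'lightning-flash[audio,image]'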