Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Audio data sources + Numpy file support (#651)
Browse files Browse the repository at this point in the history
* Initial commit

* Fixes

* Drop asteroid

* Drop asteroid

* Try fix

* Speed improvements

* Updates

* Fixes

* Updates

* Updates

* Updates

* Updates

* Updates

* Fixes

* Debug

* Debug

* Fixes

* Fixes

* Docstrings

* Fixes

* Fixes

* CHANGELOG.md
  • Loading branch information
ethanwharris authored Aug 13, 2021
1 parent ddd942d commit 6596669
Show file tree
Hide file tree
Showing 18 changed files with 338 additions and 219 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,5 @@ logs/cache/*
flash_examples/data
flash_examples/cli/*/data
timit/
urban8k_images/
__MACOSX
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added Flash Zero, a zero code command line ML platform built with flash ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611))

- Added support for `.npy` and `.npz` files to `ImageClassificationData` and `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Added support for `from_csv` to the `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Added option to pass a `resolver` to the `from_csv` and `from_pandas` methods of `ImageClassificationData`, which is used to resolve filenames given IDs ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Changed the behaviour of the `sampler` argument of the `DataModule` to take a `Sampler` type rather than instantiated object ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

### Fixed

- Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493))
Expand All @@ -50,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of `torch.jit.isinstance` ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611))

- Fixed a bug where custom samplers would not be properly forwarded to the data loader ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

## [0.4.0] - 2021-06-22

### Added
Expand Down
62 changes: 44 additions & 18 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,46 @@
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.data_source import (
DefaultDataSources,
has_file_allowed_extension,
LoaderDataFrameDataSource,
PathsDataSource,
)
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import requires_extras
from flash.image.classification.data import MatplotlibVisualization
from flash.image.data import ImageDeserializer, ImagePathsDataSource
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, requires_extras
from flash.image.classification.data import ImageClassificationData
from flash.image.data import ImageDeserializer

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS


NP_EXTENSIONS = (".npy", ".npz")


def spectrogram_loader(filepath: str):
if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
img = default_loader(filepath)
data = np.array(img)
else:
data = np.load(filepath)
return data


class AudioClassificationPathsDataSource(PathsDataSource):
@requires_extras("image")
def __init__(self):
super().__init__(loader=spectrogram_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)


class AudioClassificationDataFrameDataSource(LoaderDataFrameDataSource):
@requires_extras("image")
def __init__(self):
super().__init__(spectrogram_loader)


class AudioClassificationPreprocess(Preprocess):
Expand All @@ -31,7 +63,7 @@ def __init__(
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
spectrogram_size: Tuple[int, int] = (196, 196),
spectrogram_size: Tuple[int, int] = (128, 128),
time_mask_param: int = 80,
freq_mask_param: int = 80,
deserializer: Optional["Deserializer"] = None,
Expand All @@ -46,8 +78,10 @@ def __init__(
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.FILES: AudioClassificationPathsDataSource(),
DefaultDataSources.FOLDERS: AudioClassificationPathsDataSource(),
"data_frame": AudioClassificationDataFrameDataSource(),
DefaultDataSources.CSV: AudioClassificationDataFrameDataSource(),
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
Expand All @@ -72,15 +106,7 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)


class AudioClassificationData(DataModule):
class AudioClassificationData(ImageClassificationData):
"""Data module for audio classification."""

preprocess_cls = AudioClassificationPreprocess

def set_block_viz_window(self, value: bool) -> None:
"""Setter method to switch on/off matplotlib to pop up windows."""
self.data_fetcher.block_viz_window = value

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return MatplotlibVisualization(*args, **kwargs)
7 changes: 4 additions & 3 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

import torch
from torch import nn
from torch.utils.data._utils.collate import default_collate

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
Expand All @@ -32,12 +33,12 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]
"""The default transforms for audio classification for spectrograms: resize the spectrogram, convert the
spectrogram and target to a tensor, and collate the batch."""
return {
"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
),
"collate": kornia_collate,
"post_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
"collate": default_collate,
}


Expand Down
56 changes: 36 additions & 20 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,20 @@
# limitations under the License.
import os
import platform
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

import numpy as np
import pytorch_lightning as pl
Expand Down Expand Up @@ -86,7 +99,7 @@ def __init__(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
) -> None:

super().__init__()
Expand Down Expand Up @@ -281,7 +294,10 @@ def _train_dataloader(self) -> DataLoader:
pin_memory = True

if self.sampler is None:
sampler = None
shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))
else:
sampler = self.sampler(train_ds)

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_train_dataset(
Expand All @@ -292,14 +308,14 @@ def _train_dataloader(self) -> DataLoader:
shuffle=shuffle,
drop_last=drop_last,
collate_fn=collate_fn,
sampler=self.sampler,
sampler=sampler,
)

return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=self.sampler,
sampler=sampler,
num_workers=self.num_workers,
pin_memory=pin_memory,
drop_last=drop_last,
Expand Down Expand Up @@ -453,7 +469,7 @@ def from_data_source(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to
Expand Down Expand Up @@ -489,7 +505,7 @@ def from_data_source(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -553,7 +569,7 @@ def from_folders(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the
Expand Down Expand Up @@ -582,7 +598,7 @@ def from_folders(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -636,7 +652,7 @@ def from_files(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files
Expand Down Expand Up @@ -668,7 +684,7 @@ def from_files(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -723,7 +739,7 @@ def from_tensors(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the
Expand Down Expand Up @@ -755,7 +771,7 @@ def from_tensors(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -810,7 +826,7 @@ def from_numpy(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the
Expand Down Expand Up @@ -842,7 +858,7 @@ def from_numpy(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -896,7 +912,7 @@ def from_json(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
field: Optional[str] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -928,7 +944,7 @@ def from_json(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
field: To specify the field that holds the data in the JSON file.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -1006,7 +1022,7 @@ def from_csv(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the
Expand Down Expand Up @@ -1037,7 +1053,7 @@ def from_csv(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -1090,7 +1106,7 @@ def from_datasets(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the
Expand Down Expand Up @@ -1119,7 +1135,7 @@ def from_datasets(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down
Loading

0 comments on commit 6596669

Please sign in to comment.