diff --git a/example/gluon/urban_sounds/README.md b/example/gluon/urban_sounds/README.md
new file mode 100644
index 000000000000..f7e33136092f
--- /dev/null
+++ b/example/gluon/urban_sounds/README.md
@@ -0,0 +1,22 @@
+# Urban Sounds Classification in MXNet
+
+Urban Sounds Dataset:
+## Description
+  The dataset contains 8732 wav files, which are audio samples (<= 4s) of street sounds such as engine_idling, car_horn, children_playing, dog_barking and so on.
+  The task is to classify these audio samples into one of the 10 labels.
+
+To be able to run this example:
+
+1. Download the dataset (train.zip, test.zip) required for this example from the location:
+   **https://drive.google.com/drive/folders/0By0bAi7hOBAFUHVXd1JCN3MwTEU**
+
+
+2. Extract both zip archives into the **current directory** - after unzipping you will get two new folders,\
+   **Train** and **Test**, and two csv files - **train.csv**, **test.csv**
+
+3. Make sure Apache MXNet is installed on the machine. For instructions, go to the link: **https://mxnet.incubator.apache.org/install/**
+
+4. Make sure Librosa is installed. To install, run
+   `pip install librosa`.
+   For more details, refer here:
+   **https://librosa.github.io/librosa/install.html**
diff --git a/example/gluon/urban_sounds/__init__.py b/example/gluon/urban_sounds/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/example/gluon/urban_sounds/model.py b/example/gluon/urban_sounds/model.py
new file mode 100644
index 000000000000..3b3c3500c2bb
--- /dev/null
+++ b/example/gluon/urban_sounds/model.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+    This module builds an MLP model with a configurable output layer (number of units in the last layer).
+    Users can pass any number of units for the last layer. Since this dataset has 10 labels,
+    the default value of num_labels is 10.
+"""
+import mxnet as mx
+from mxnet import gluon
+
+# Define a simple MLP whose output layer size equals the number of labels
+def get_net(num_labels=10):
+    net = gluon.nn.Sequential()
+    with net.name_scope():
+        net.add(gluon.nn.Dense(256, activation="relu")) # 1st hidden layer (256 nodes)
+        net.add(gluon.nn.Dense(256, activation="relu")) # 2nd hidden layer (256 nodes)
+        net.add(gluon.nn.Dense(num_labels))             # output layer
+    net.collect_params().initialize(mx.init.Normal(1.))
+    return net
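A minimal sketch of how `get_net` above is meant to be wired up, assuming a batch of MFCC-style feature vectors; the shapes here are illustrative, not taken from the dataset:

```python
import mxnet as mx
from model import get_net

# A batch of 4 "feature vectors" standing in for MFCC features (shapes are illustrative).
features = mx.nd.random.uniform(shape=(4, 20))
net = get_net(num_labels=10)      # parameters are initialized inside get_net
logits = net(features)            # Dense layers infer their input size on the first call
print(logits.shape)               # -> (4, 10), one score per label
```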
diff --git a/example/gluon/urban_sounds/predict.py b/example/gluon/urban_sounds/predict.py
new file mode 100644
index 000000000000..7adfa9451489
--- /dev/null
+++ b/example/gluon/urban_sounds/predict.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Prediction module for Urban Sounds Classification
+"""
+import os
+import warnings
+import mxnet as mx
+from mxnet import nd
+from mxnet.gluon.contrib.data.audio.transforms import MFCC
+from model import get_net
+
+def predict(prediction_dir='./Test'):
+    """Run predictions on the audio files in the directory `prediction_dir`.
+
+    Parameters
+    ----------
+    prediction_dir: string, default ./Test
+        The directory that contains the audio files on which predictions are to be made
+
+    """
+    try:
+        import librosa
+    except ImportError:
+        warnings.warn("Librosa is not installed! Please run `pip install librosa` and try again.")
+        return
+
+    if not os.path.exists(prediction_dir):
+        warnings.warn("The directory on which predictions are to be made was not found!")
+        return
+
+    if len(os.listdir(prediction_dir)) == 0:
+        warnings.warn("The directory on which predictions are to be made is empty! Exiting...")
+        return
+
+    # Loading synsets
+    if not os.path.exists('./synset.txt'):
+        warnings.warn("The synset (labels) file for the dataset does not exist. Please run the training script first.")
+        return
+
+    with open("./synset.txt", "r") as f:
+        synset = [l.rstrip() for l in f]
+    net = get_net(len(synset))
+    print("Trying to load the model with the saved parameters...")
+    if not os.path.exists("./net.params"):
+        warnings.warn("The model does not have any saved parameters... Cannot proceed! Train the model first.")
+        return
+
+    net.load_parameters("./net.params")
+    file_names = os.listdir(prediction_dir)
+    full_file_names = [os.path.join(prediction_dir, item) for item in file_names]
+    mfcc = MFCC()
+    print("\nStarting predictions for audio files in ", prediction_dir, " ....\n")
+    for filename in full_file_names:
+        # res_type='kaiser_fast' loads faster than the default 'kaiser_best', reducing load time
+        X1, _ = librosa.load(filename, res_type='kaiser_fast')
+        transformed_test_data = mfcc(mx.nd.array(X1))
+        output = net(transformed_test_data.reshape((1, -1)))
+        prediction = nd.argmax(output, axis=1)
+        print(filename, " -> ", synset[int(prediction.asscalar())])
+
+
+if __name__ == '__main__':
+    try:
+        import argparse
+        parser = argparse.ArgumentParser(description="Urban Sounds classification example - MXNet")
+        parser.add_argument('--pred', '-p', help="Enter the folder path that contains your audio files",
+                            type=str, default='./Test')
+        args = parser.parse_args()
+        pred_dir = args.pred
+
+    except ImportError:
+        warnings.warn("The argparse module is not installed! Passing default arguments.")
+        pred_dir = './Test'
+    predict(prediction_dir=pred_dir)
+    print("Urban sounds classification prediction DONE!")
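To exercise `predict` without the CLI, a minimal sketch assuming training has already produced `net.params` and `synset.txt` in the current directory (the file names the scripts in this diff use):

```python
from predict import predict

# Point at any folder of .wav files; './Test' is the default from the README layout.
# Equivalent to: python predict.py --pred ./Test
predict(prediction_dir='./Test')
```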
Cannot continue") + return + if not train_dir or not os.path.exists(train_dir) or not train_csv: + warnings.warn("No train directory could be found ") + return + # Make a dataset from the local folder containing Audio data + print("\nMaking an Audio Dataset...\n") + tick = time.time() + aud_dataset = AudioFolderDataset(train_dir, train_csv=train_csv, file_format='.wav', skip_rows=1) + tock = time.time() + + print("Loading the dataset took ", (tock-tick), " seconds.") + print("\n=======================================\n") + print("Number of output classes = ", len(aud_dataset.synsets)) + print("\nThe labels are : \n") + print(aud_dataset.synsets) + # Get the model to train + net = model.get_net(len(aud_dataset.synsets)) + print("\nNeural Network = \n") + print(net) + print("\nModel - Neural Network Generated!\n") + print("=======================================\n") + + #Define the loss - Softmax CE Loss + softmax_loss = gluon.loss.SoftmaxCELoss(from_logits=False, sparse_label=True) + print("Loss function initialized!\n") + print("=======================================\n") + + #Define the trainer with the optimizer + trainer = gluon.Trainer(net.collect_params(), 'adadelta') + print("Optimizer - Trainer function initialized!\n") + print("=======================================\n") + print("Loading the dataset to the Gluon's OOTB Dataloader...") + + #Getting the data loader out of the AudioDataset and passing the transform + aud_transform = MFCC() + tick = time.time() + + audio_train_loader = gluon.data.DataLoader(aud_dataset.transform_first(aud_transform), batch_size=32, shuffle=True) + tock = time.time() + print("Time taken to load data and apply transform here is ", (tock-tick), " seconds.") + print("=======================================\n") + + + print("Starting the training....\n") + # Training loop + tick = time.time() + batch_size = batch_size + num_examples = len(aud_dataset) + + for e in range(epochs): + cumulative_loss = 0 + for _, (data, label) in enumerate(audio_train_loader): + with autograd.record(): + output = net(data) + loss = softmax_loss(output, label) + loss.backward() + + trainer.step(batch_size) + cumulative_loss += mx.nd.sum(loss).asscalar() + + if e%5 == 0: + train_accuracy = evaluate_accuracy(audio_train_loader, net) + print("Epoch %s. 
Loss: %s Train accuracy : %s " % (e, cumulative_loss/num_examples, train_accuracy)) + print("\n------------------------------\n") + + train_accuracy = evaluate_accuracy(audio_train_loader, net) + tock = time.time() + print("\nFinal training accuracy: ", train_accuracy) + + print("Training the sound classification for ", epochs, " epochs, MLP model took ", (tock-tick), " seconds") + print("====================== END ======================\n") + + print("Trying to save the model parameters here...") + net.save_parameters("./net.params") + print("Saved the model parameters in current directory.") + + +if __name__ == '__main__': + + try: + import argparse + parser = argparse.ArgumentParser(description="Urban Sounds clsssification example - MXNet") + parser.add_argument('--train', '-t', help="Enter the folder path that contains your audio files", type=str) + parser.add_argument('--csv', '-c', help="Enter the filename of the csv that contains filename\ + to label mapping", type=str) + parser.add_argument('--epochs', '-e', help="Enter the number of epochs \ + you would want to run the training for.", type=int) + parser.add_argument('--batch_size', '-b', help="Enter the batch_size of data", type=int) + args = parser.parse_args() + + if args: + if args.train: + training_dir = args.train + else: + training_dir = './Train' + + if args.csv: + training_csv = args.csv + else: + training_csv = './train.csv' + + if args.epochs: + eps = args.epochs + else: + eps = 30 + + if args.batch_size: + batch_sz = args.batch_size + else: + batch_sz = 32 + + except ImportError as er: + warnings.warn("Argument parsing module could not be imported \ + Passing default arguments.") + training_dir = './Train' + training_csv = './train.csv' + eps = 30 + batch_sz = 32 + + train(train_dir=training_dir, train_csv=training_csv, epochs=eps, batch_size=batch_sz) + print("Urban sounds classification Training DONE!") diff --git a/python/mxnet/gluon/contrib/data/__init__.py b/python/mxnet/gluon/contrib/data/__init__.py index 7cb25eb7498e..bc2f09194864 100644 --- a/python/mxnet/gluon/contrib/data/__init__.py +++ b/python/mxnet/gluon/contrib/data/__init__.py @@ -22,3 +22,5 @@ from . import text from .sampler import * + +from . import audio diff --git a/python/mxnet/gluon/contrib/data/audio/__init__.py b/python/mxnet/gluon/contrib/data/audio/__init__.py new file mode 100644 index 000000000000..ef4e5e11bcff --- /dev/null +++ b/python/mxnet/gluon/contrib/data/audio/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
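The same training entry point can be driven from Python instead of the CLI; a minimal sketch assuming the **Train** folder and **train.csv** from the README sit in the current directory (values are just the defaults above):

```python
from train import train

# Uses the folder layout described in the README; adjust paths as needed.
train(train_dir='./Train', train_csv='./train.csv', epochs=30, batch_size=32)
# Afterwards, net.params and synset.txt are written to the current directory,
# which is exactly what predict.py expects to find.
```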
diff --git a/python/mxnet/gluon/contrib/data/audio/__init__.py b/python/mxnet/gluon/contrib/data/audio/__init__.py
new file mode 100644
index 000000000000..ef4e5e11bcff
--- /dev/null
+++ b/python/mxnet/gluon/contrib/data/audio/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Audio utilities."""
diff --git a/python/mxnet/gluon/contrib/data/audio/datasets.py b/python/mxnet/gluon/contrib/data/audio/datasets.py
new file mode 100644
index 000000000000..d29cbcaca078
--- /dev/null
+++ b/python/mxnet/gluon/contrib/data/audio/datasets.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Audio Dataset container."""
+__all__ = ['AudioFolderDataset']
+
+import os
+import warnings
+from ....data import Dataset
+from ..... import ndarray as nd
+try:
+    import librosa
+except ImportError:
+    librosa = None
+    warnings.warn("gluon/contrib/data/audio/datasets.py : librosa dependency could not be resolved or \
+    imported, cannot load audio into a numpy array.")
+
+
+class AudioFolderDataset(Dataset):
+    """A dataset for loading audio files stored in a folder structure like::
+
+        root/children_playing/0.wav
+        root/siren/23.wav
+        root/drilling/26.wav
+        root/dog_barking/42.wav
+
+    OR files (wav) and a csv file that has the filename and its associated label.
+
+    Parameters
+    ----------
+    root : str
+        Path to root directory.
+    train_csv: str, default None
+        train_csv should be populated with the training csv filename
+    file_format: str, default '.wav'
+        The format of the audio files (.wav, .mp3)
+    skip_rows: int, default 0
+        While reading from the csv file, how many rows to skip at the start of the file to avoid reading in the header
+
+    Attributes
+    ----------
+    synsets : list
+        List of class names. `synsets[i]` is the name for the integer label `i`
+    items : list of tuples
+        List of all audio in (filename, label) pairs.
+    """
+    def __init__(self, root, train_csv=None, file_format='.wav', skip_rows=0):
+        if not librosa:
+            warnings.warn("Run `pip install librosa` to continue.")
+            return
+        self._root = os.path.expanduser(root)
+        self._exts = ['.wav']
+        self._format = file_format
+        self._train_csv = train_csv
+        if file_format.lower() not in self._exts:
+            warnings.warn("format {} is not supported currently.".format(file_format))
+            return
+        self._list_audio_files(self._root, skip_rows=skip_rows)
+
+
+    def _list_audio_files(self, root, skip_rows=0):
+        """
+        Populates synsets - a map of index to label for the data items -
+        and the data in the dataset as tuples of (data, label).
+        """
+        self.synsets = []
+        self.items = []
+        if self._train_csv is None:
+            for folder in sorted(os.listdir(root)):
+                path = os.path.join(root, folder)
+                if not os.path.isdir(path):
+                    warnings.warn('Ignoring %s, which is not a directory.' % path, stacklevel=3)
+                    continue
+                label = len(self.synsets)
+                self.synsets.append(folder)
+                for filename in sorted(os.listdir(path)):
+                    file_name = os.path.join(path, filename)
+                    ext = os.path.splitext(file_name)[1]
+                    if ext.lower() not in self._exts:
+                        warnings.warn('Ignoring %s of type %s. Only support %s' % (filename, ext, ', '.join(self._exts)))
+                        continue
+                    self.items.append((file_name, label))
+        else:
+            data_tmp = []
+            label_tmp = []
+            skipped_rows = 0
+            with open(self._train_csv, "r") as traincsv:
+                for line in traincsv:
+                    skipped_rows = skipped_rows + 1
+                    if skipped_rows <= skip_rows:
+                        continue
+                    filename = os.path.join(root, line.split(",")[0])
+                    label = line.split(",")[1].strip()
+                    if label not in self.synsets:
+                        self.synsets.append(label)
+                    data_tmp.append(filename)
+                    label_tmp.append(self.synsets.index(label))
+
+            # Generating the synset.txt file now
+            with open("./synset.txt", "w") as synsets_file:
+                for item in self.synsets:
+                    synsets_file.write(item + os.linesep)
+            print("Synsets is generated as synset.txt")
+
+            self._label = nd.array(label_tmp)
+            for i, _ in enumerate(data_tmp):
+                if self._format not in data_tmp[i]:
+                    self.items.append((data_tmp[i] + self._format, self._label[i]))
+
+    def __getitem__(self, idx):
+        """
+        Retrieve the item (data, label) stored at idx in items
+        """
+        filename = self.items[idx][0]
+        label = self.items[idx][1]
+
+        if librosa is not None:
+            X1, _ = librosa.load(filename, res_type='kaiser_fast')
+            return nd.array(X1), label
+
+        warnings.warn("The dependency librosa is not installed! \
+        Cannot load the audio (wav) file into a numpy.ndarray.")
+        return filename, label
+
+    def __len__(self):
+        """
+        Retrieves the number of items in the dataset
+        """
+        return len(self.items)
+
+
+    def transform_first(self, fn, lazy=True):
+        """Returns a new dataset with the first element of each sample
+        transformed by the transformer function `fn`.
+
+        This is useful, for example, when you only want to transform data
+        while keeping label as is.
+
+        Parameters
+        ----------
+        fn : callable
+            A transformer function that takes the first element of a sample
+            as input and returns the transformed element.
+        lazy : bool, default True
+            If False, transforms all samples at once. Otherwise,
+            transforms each sample on demand. Note that if `fn`
+            is stochastic, you must set lazy to True or you will
+            get the same result on all epochs. This override always
+            transforms eagerly, so the `lazy` argument is effectively
+            ignored and False is passed through.
+
+        Returns
+        -------
+        Dataset
+            The transformed dataset.
+        """
+        return super(AudioFolderDataset, self).transform_first(fn, lazy=False)
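A sketch of how `AudioFolderDataset` composes with Gluon's `DataLoader`, assuming the csv layout from the example's README (a header row, then `filename,label` pairs); the paths are placeholders:

```python
from mxnet import gluon
from mxnet.gluon.contrib.data.audio.datasets import AudioFolderDataset
from mxnet.gluon.contrib.data.audio.transforms import MFCC

# skip_rows=1 skips the csv header; each item comes back as (waveform NDArray, label)
dataset = AudioFolderDataset('./Train', train_csv='./train.csv', file_format='.wav', skip_rows=1)
loader = gluon.data.DataLoader(dataset.transform_first(MFCC()), batch_size=32, shuffle=True)
for features, labels in loader:
    print(features.shape, labels.shape)  # e.g. (32, 20) MFCC vectors and (32,) labels
    break
```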
diff --git a/python/mxnet/gluon/contrib/data/audio/transforms.py b/python/mxnet/gluon/contrib/data/audio/transforms.py
new file mode 100644
index 000000000000..486d88d4a363
--- /dev/null
+++ b/python/mxnet/gluon/contrib/data/audio/transforms.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable= arguments-differ
+"""Audio transforms."""
+
+import warnings
+import numpy as np
+try:
+    import librosa
+except ImportError:
+    librosa = None
+    warnings.warn("gluon/contrib/data/audio/transforms.py : librosa dependency could not be resolved or \
+    imported, some/all transforms will not be available.")
+
+from ..... import ndarray as nd
+from ....block import Block
+
+class MFCC(Block):
+    """Extracts Mel-frequency cepstral coefficients from the audio data file.
+    More details : https://librosa.github.io/librosa/generated/librosa.feature.mfcc.html
+
+    Attributes
+    ----------
+    sampling_rate: int, default 22050
+        sampling rate of the input audio signal
+    num_mfcc: int, default 20
+        number of MFCCs to return
+
+
+    Inputs:
+        - **x**: input tensor of (samples, ) shape.
+
+    Outputs:
+        - **out**: output array, an NDArray of (num_mfcc, ) shape.
+
+    """
+
+    def __init__(self, sampling_rate=22050, num_mfcc=20):
+        self._sampling_rate = sampling_rate
+        self._num_mfcc = num_mfcc
+        super(MFCC, self).__init__()
+
+    def forward(self, x):
+        if not librosa:
+            warnings.warn("The librosa dependency is not installed! Install it and retry.")
+            return x
+        if isinstance(x, np.ndarray):
+            y = x
+        elif isinstance(x, nd.NDArray):
+            y = x.asnumpy()
+        else:
+            warnings.warn("MFCC - allowed datatypes are mx.nd.NDArray and numpy.ndarray")
+            return x
+
+        audio_tmp = np.mean(librosa.feature.mfcc(y=y, sr=self._sampling_rate, n_mfcc=self._num_mfcc).T, axis=0)
+        return nd.array(audio_tmp)
+
+
+class Scale(Block):
+    """Scale audio numpy.ndarray from a 16-bit integer to a floating point number between
+    -1.0 and 1.0. The 16-bit integer is the sample resolution or bit depth.
+
+    Attributes
+    ----------
+    scale_factor : float
+        The factor to scale the input tensor by.
+
+
+    Inputs:
+        - **x**: input tensor of (samples, ) shape.
+
+    Outputs:
+        - **out**: output array, a scaled NDArray of (samples, ) shape.
+
+    Examples
+    --------
+    >>> scale = audio.transforms.Scale(scale_factor=2)
+    >>> audio_samples = mx.nd.array([2,3,4])
+    >>> scale(audio_samples)
+    [1.  1.5 2. ]
+
+
+    """
+
+    def __init__(self, scale_factor=2**31):
+        self.scale_factor = scale_factor
+        super(Scale, self).__init__()
+
+    def forward(self, x):
+        if isinstance(x, np.ndarray):
+            return nd.array(x / self.scale_factor)
+        return x / self.scale_factor
+
+
+class PadTrim(Block):
+    """Pad/Trim a 1-D NDArray or numpy.ndarray (signal or labels)
+
+    Attributes
+    ----------
+    max_len : int
+        Length to which the array will be padded or trimmed.
+    fill_value: int or float
+        If padding is needed, the value to pad at the end of the input array
+
+
+    Inputs:
+        - **x**: input tensor of (samples, ) shape.
+
+    Outputs:
+        - **out**: output array, an NDArray of (max_len, ) shape.
+
+    Examples
+    --------
+    >>> padtrim = audio.transforms.PadTrim(max_len=9, fill_value=0)
+    >>> audio_samples = mx.nd.array([1,2,3,4,5])
+    >>> padtrim(audio_samples)
+    [1. 2. 3. 4. 5. 0. 0. 0. 0.]
+
+
+    """
+
+    def __init__(self, max_len, fill_value=0):
+        self._max_len = max_len
+        self._fill_value = fill_value
+        super(PadTrim, self).__init__()
+
+    def forward(self, x):
+        if isinstance(x, np.ndarray):
+            x = nd.array(x)
+        if self._max_len > x.size:
+            pad = nd.ones((self._max_len - x.size,)) * self._fill_value
+            x = nd.concat(x, pad, dim=0)
+        elif self._max_len < x.size:
+            x = x[:self._max_len]
+        return x
+
+
+class MEL(Block):
+    """Create MEL spectrograms from a raw audio signal. Note: this transform is relatively slow.
+
+    Attributes
+    ----------
+    sampling_rate: int, default 22050
+        sampling rate of the input audio signal
+    num_fft: int, default 2048
+        length of the Fast Fourier Transform window
+    num_mels: int, default 20
+        number of mel bands to generate
+    hop_length: int, default 512
+        total samples between successive frames
+
+
+    Inputs:
+        - **x**: input tensor of (samples, ) shape.
+
+    Outputs:
+        - **out**: output array of mel spectrograms, shape = (n_mels, 1)
+
+    Usage (see the librosa.feature.melspectrogram docs):
+        MEL(sampling_rate=16000, num_fft=1600, hop_length=800, num_mels=64)
+
+    Examples
+    --------
+    >>> mel = audio.transforms.MEL()
+    >>> audio_samples = mx.nd.array([1,2,3,4,5])
+    >>> mel(audio_samples)
+    [[3.81801406e+04]
+     [9.86858240e-29]
+     [1.87405472e-29]
+     [2.38637225e-29]
+     [3.94043010e-29]
+     [3.67071565e-29]
+     [7.29390295e-29]
+     [8.84324438e-30]...
+
+
+    """
+
+    def __init__(self, sampling_rate=22050, num_fft=2048, num_mels=20, hop_length=512):
+        self._sampling_rate = sampling_rate
+        self._num_fft = num_fft
+        self._num_mels = num_mels
+        self._hop_length = hop_length
+        super(MEL, self).__init__()
+
+    def forward(self, x):
+        if librosa is None:
+            warnings.warn("Cannot create spectrograms, since the librosa dependency is not installed!")
+            return x
+        if isinstance(x, nd.NDArray):
+            x = x.asnumpy()
+        specs = librosa.feature.melspectrogram(x, sr=self._sampling_rate,\
+            n_fft=self._num_fft, n_mels=self._num_mels, hop_length=self._hop_length)
+        return nd.array(specs)
\ No newline at end of file
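A short, self-contained sketch of the transforms above on synthetic data; `Scale` and `PadTrim` are pure NDArray operations, while `MFCC` assumes librosa is installed:

```python
import mxnet as mx
from mxnet.gluon.contrib.data.audio.transforms import Scale, PadTrim, MFCC

samples = mx.nd.array([2, 3, 4])
print(Scale(scale_factor=2)(samples))            # -> [1.  1.5 2. ]
print(PadTrim(max_len=5, fill_value=0)(samples)) # -> [2. 3. 4. 0. 0.]

# MFCC needs librosa; with defaults it maps a waveform to a (20,) feature vector.
waveform = mx.nd.random.uniform(shape=(22050,))
print(MFCC()(waveform).shape)                    # -> (20,)
```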
+ +"""Testing audio transforms in gluon container.""" +from __future__ import print_function +import warnings +import numpy as np +from mxnet import gluon +from mxnet.gluon.contrib.data.audio import transforms +from mxnet.test_utils import assert_almost_equal +from common import with_seed + + +@with_seed() +def test_pad_trim(): + """ + Function to test Pad/Trim Audio transform + """ + data_in = np.random.randint(1, high=20, size=(15)) + # trying trimming the audio samples here... + max_len = 10 + pad_trim = gluon.contrib.data.audio.transforms.PadTrim(max_len=max_len) + trimmed_audio = pad_trim(data_in) + np_trimmed = data_in[:max_len] + assert_almost_equal(trimmed_audio.asnumpy(), np_trimmed) + + #trying padding here... + max_len = 25 + fill_value = 0 + pad_trim = transforms.PadTrim(max_len=max_len, fill_value=fill_value) + np_padded = np.pad(data_in, pad_width=max_len-len(data_in), mode='constant', \ + constant_values=fill_value)[max_len-len(data_in):] + padded_audio = pad_trim(data_in) + assert_almost_equal(padded_audio.asnumpy(), np_padded) + + +@with_seed() +def test_scale(): + """ + Function to test scaling of the audio transform + """ + data_in = np.random.randint(1, high=20, size=(15)) + # Scaling the audio signal meaning dividing each sample by the scaling factor + scale_factor = 2.0 + scaled_numpy = data_in /scale_factor + scale = transforms.Scale(scale_factor=scale_factor) + scaled_audio = scale(data_in) + assert_almost_equal(scaled_audio.asnumpy(), scaled_numpy) + + +@with_seed() +def test_mfcc(): + """ + Function to test extraction of mfcc from audio signal + """ + try: + import librosa + except ImportError: + warnings.warn("Librosa not installed! pip install librosa and then continue.") + return + audio_samples = np.random.rand(20) + n_mfcc = 64 + mfcc = gluon.contrib.data.audio.transforms.MFCC(num_mfcc=n_mfcc) + + mfcc_features = mfcc(audio_samples) + assert mfcc_features.shape[0] == n_mfcc + + +@with_seed() +def test_mel(): + """ + Function to test extraction of MEL spectrograms from audio signal + """ + try: + import librosa + except ImportError: + warnings.warn("Librosa not installed! pip install librosa and then continue.") + return + audio_samples = np.random.rand(20) + n_mels = 256 + mel = gluon.contrib.data.audio.transforms.MEL(num_mels=n_mels) + + mels = mel(audio_samples) + assert mels.shape[0] == n_mels + +if __name__ == '__main__': + import nose + nose.runmodule()