diff --git a/example/gluon/audio/transforms.py b/example/gluon/audio/transforms.py
new file mode 100644
index 000000000000..8b76d131cdb1
--- /dev/null
+++ b/example/gluon/audio/transforms.py
@@ -0,0 +1,205 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# coding: utf-8
+# pylint: disable= arguments-differ
+"""Audio transforms."""
+
+import warnings
+import numpy as np
+try:
+    import librosa
+except ImportError:
+    warnings.warn("librosa could not be imported; "
+                  "some or all transforms will be unavailable.")
+
+from mxnet import ndarray as nd
+from mxnet.gluon.block import Block
+
+
+class MFCC(Block):
+    """Extracts Mel-frequency cepstral coefficients (MFCCs) from an audio signal.
+    More details: https://librosa.github.io/librosa/generated/librosa.feature.mfcc.html
+
+    Attributes
+    ----------
+    sampling_rate: int, default 22050
+        sampling rate of the input audio signal
+    num_mfcc: int, default 20
+        number of MFCCs to return
+
+
+    Inputs:
+        - **x**: input tensor of shape (samples,).
+
+    Outputs:
+        - **out**: NDArray of shape (num_mfcc,) holding the MFCCs averaged over time.
+
+    """
+
+    def __init__(self, sampling_rate=22050, num_mfcc=20):
+        self._sampling_rate = sampling_rate
+        self._num_mfcc = num_mfcc
+        super(MFCC, self).__init__()
+
+    def forward(self, x):
+        if isinstance(x, np.ndarray):
+            y = x
+        elif isinstance(x, nd.NDArray):
+            y = x.asnumpy()
+        else:
+            warnings.warn("MFCC - allowed datatypes are mx.nd.NDArray and numpy.ndarray")
+            return x
+
+        audio_tmp = np.mean(librosa.feature.mfcc(y=y, sr=self._sampling_rate,
+                                                 n_mfcc=self._num_mfcc).T, axis=0)
+        return nd.array(audio_tmp)
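+
+# A minimal usage sketch for MFCC (assumes librosa is installed and that
+# `sample.wav` is a hypothetical mono audio file):
+#
+#     >>> import librosa
+#     >>> samples, sr = librosa.load("sample.wav", res_type='kaiser_fast')
+#     >>> mfcc = MFCC(sampling_rate=sr, num_mfcc=20)
+#     >>> features = mfcc(samples)  # NDArray of shape (20,)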
+
+
+class Scale(Block):
+    """Scale an audio array from integer sample values to floating point values
+    between -1.0 and 1.0 by dividing by ``scale_factor``. The default factor of
+    2**31 corresponds to a 32-bit signed sample resolution (bit depth).
+
+    Attributes
+    ----------
+    scale_factor : float
+        The factor to divide the input tensor by.
+
+
+    Inputs:
+        - **x**: input tensor of shape (samples,).
+
+    Outputs:
+        - **out**: scaled NDArray of shape (samples,).
+
+    Examples
+    --------
+    >>> scale = audio.transforms.Scale(scale_factor=2)
+    >>> audio_samples = mx.nd.array([2,3,4])
+    >>> scale(audio_samples)
+    [1.  1.5 2. ]
+
+
+    """
+
+    def __init__(self, scale_factor=2**31):
+        self.scale_factor = scale_factor
+        super(Scale, self).__init__()
+
+    def forward(self, x):
+        if self.scale_factor == 0:
+            warnings.warn("Scale factor cannot be 0.")
+            return x
+        if isinstance(x, np.ndarray):
+            return nd.array(x / self.scale_factor)
+        return x / self.scale_factor
+
+
+class PadTrim(Block):
+    """Pad or trim a 1-D NDArray or numpy.ndarray (signal or labels).
+
+    Attributes
+    ----------
+    max_len : int
+        Length to which the array will be padded or trimmed.
+    fill_value: int or float
+        The value to pad with at the end of the input array, if padding is needed.
+
+
+    Inputs:
+        - **x**: input tensor of shape (samples,).
+
+    Outputs:
+        - **out**: NDArray of shape (max_len,).
+
+    Examples
+    --------
+    >>> padtrim = audio.transforms.PadTrim(max_len=9, fill_value=0)
+    >>> audio_samples = mx.nd.array([1,2,3,4,5])
+    >>> padtrim(audio_samples)
+    [1. 2. 3. 4. 5. 0. 0. 0. 0.]
+
+
+    """
+
+    def __init__(self, max_len, fill_value=0):
+        self._max_len = max_len
+        self._fill_value = fill_value
+        super(PadTrim, self).__init__()
+
+    def forward(self, x):
+        if isinstance(x, np.ndarray):
+            x = nd.array(x)
+        if self._max_len > x.size:
+            pad = nd.ones((self._max_len - x.size,)) * self._fill_value
+            x = nd.concat(x, pad, dim=0)
+        elif self._max_len < x.size:
+            x = x[:self._max_len]
+        return x
+
+
+class MEL(Block):
+    """Create mel spectrograms from a raw audio signal. This operation is
+    relatively slow.
+
+    Attributes
+    ----------
+    sampling_rate: int, default 22050
+        sampling rate of the input audio signal
+    num_fft: int, default 2048
+        length of the Fast Fourier transform window
+    num_mels: int, default 20
+        number of mel bands to generate
+    hop_length: int, default 512
+        number of samples between successive frames
+
+
+    Inputs:
+        - **x**: input tensor of shape (samples,).
+
+    Outputs:
+        - **out**: NDArray of mel spectrograms, shape = (n_mels, 1)
+
+    Usage (see the librosa.feature.melspectrogram docs):
+        MEL(sampling_rate=16000, num_fft=1600, hop_length=800, num_mels=64)
+
+    Examples
+    --------
+    >>> mel = audio.transforms.MEL()
+    >>> audio_samples = mx.nd.array([1,2,3,4,5])
+    >>> mel(audio_samples)
+    [[3.81801406e+04]
+     [9.86858240e-29]
+     [1.87405472e-29]
+     [2.38637225e-29]
+     [3.94043010e-29]
+     [3.67071565e-29]
+     [7.29390295e-29]
+     [8.84324438e-30]...
+
+
+    """
+
+    def __init__(self, sampling_rate=22050, num_fft=2048, num_mels=20, hop_length=512):
+        self._sampling_rate = sampling_rate
+        self._num_fft = num_fft
+        self._num_mels = num_mels
+        self._hop_length = hop_length
+        super(MEL, self).__init__()
+
+    def forward(self, x):
+        if isinstance(x, nd.NDArray):
+            x = x.asnumpy()
+        specs = librosa.feature.melspectrogram(y=x, sr=self._sampling_rate,
+                                               n_fft=self._num_fft, n_mels=self._num_mels,
+                                               hop_length=self._hop_length)
+        return nd.array(specs)
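+
+# A hedged sketch of chaining these transforms on a raw signal (assumes
+# librosa is installed and `sample.wav` is a hypothetical file):
+#
+#     >>> import librosa
+#     >>> y, sr = librosa.load("sample.wav", res_type='kaiser_fast')
+#     >>> y = PadTrim(max_len=4 * sr)(nd.array(y))  # fix the length to 4 seconds
+#     >>> features = MFCC(sampling_rate=sr)(y)      # NDArray of shape (20,)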
diff --git a/example/gluon/audio/urban_sounds/README.md b/example/gluon/audio/urban_sounds/README.md
new file mode 100644
index 000000000000..c85d29db2e5a
--- /dev/null
+++ b/example/gluon/audio/urban_sounds/README.md
@@ -0,0 +1,100 @@
+# Urban Sounds Classification in MXNet Gluon
+
+This example provides an end-to-end pipeline for the Urban Sounds Classification practice problem, a common datahack competition:
+https://datahack.analyticsvidhya.com/contest/practice-problem-urban-sound-classification/
+
+After logging in, the dataset can be downloaded.
+The details of the dataset and the link to download it are given below.
+
+
+## Urban Sounds Dataset:
+### Description
+ The dataset contains 8732 wav files, which are audio samples (<= 4 s) of street sounds such as engine_idling, car_horn, children_playing, dog_barking and so on.
+ The task is to classify these audio samples into one of the following 10 labels:
+ ```
+ siren,
+ street_music,
+ drilling,
+ dog_bark,
+ children_playing,
+ gun_shot,
+ engine_idling,
+ air_conditioner,
+ jackhammer,
+ car_horn
+ ```
+
+To be able to run this example:
+
+1. `pip install -r requirements.txt`
+
+    Run this from the directory that contains the requirements.txt file
+    to install the libraries required by the example.
+    The main dependency is librosa; the example was tested with version `0.6.2`.
+    For more details, refer to:
+https://librosa.github.io/librosa/install.html
+
+2. Download the dataset (train.zip, test.zip) required for this example from:
+https://drive.google.com/drive/folders/0By0bAi7hOBAFUHVXd1JCN3MwTEU
+
+3. Extract both zip archives into the **current directory** - after unzipping you get two new folders, **Train** and **Test**, and two csv files - **train.csv**, **test.csv**
+
+    Assuming you are in a directory *"UrbanSounds"*, after downloading and extracting train.zip the folder structure should be:
+
+    ```
+    UrbanSounds
+        - Train
+            - 0.wav, 1.wav ...
+        - train.csv
+        - train.py
+        - predict.py ...
+    ```
+
+4. Install Apache MXNet. For instructions, see: https://mxnet.incubator.apache.org/install/
+
+
+
+For information on the current design of the AudioFolderDataset, refer to:
+https://cwiki.apache.org/confluence/display/MXNET/Gluon+-+Audio
+
+### Usage
+
+For training:
+
+- Arguments
+  - train : The folder/directory that contains the audio (wav) files locally. Default = "./Train"
+  - csv : The name of the csv file that contains the audio-file-name-to-label mapping. Default = "train.csv"
+  - epochs : Number of epochs to train the model. Default = 30
+  - batch_size : The batch size for training. Default = 32
+
+
+###### To use the default arguments, use:
+```
+python train.py
+```
+or
+
+###### To pass command-line arguments for the training data directory, epochs, batch_size and csv file name, use:
+```
+python train.py --train ./Train --csv train.csv --batch_size 32 --epochs 30
+```
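+
+###### To build the same training pipeline programmatically:
+
+The snippet below is a minimal sketch; it assumes the **Train** folder and **train.csv** described above are in the current directory, and that the parent `audio` directory (which holds `transforms.py`) is on the Python path, as `train.py` arranges via `sys.path.append('../')`:
+
+```
+from mxnet import gluon
+from datasets import AudioFolderDataset
+from transforms import MFCC
+
+dataset = AudioFolderDataset('./Train', train_csv='./train.csv', skip_header=True)
+dataset = dataset.transform_first(MFCC(), lazy=False)  # precompute MFCC features
+loader = gluon.data.DataLoader(dataset, batch_size=32, shuffle=True)
+```
+
+`lazy=False` computes all features in one pass up front, which is what `train.py` relies on for throughput.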
+
+For prediction:
+
+- Arguments
+  - pred : The folder/directory that contains the audio (wav) files which are to be classified. Default = "./Test"
+
+
+###### To use the default arguments, use:
+```
+python predict.py
+```
+or
+
+###### To pass a command-line argument for the test data directory, use:
+```
+python predict.py --pred ./Test
+```
\ No newline at end of file
diff --git a/example/gluon/audio/urban_sounds/datasets.py b/example/gluon/audio/urban_sounds/datasets.py
new file mode 100644
index 000000000000..51c040c8f162
--- /dev/null
+++ b/example/gluon/audio/urban_sounds/datasets.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+""" Audio Dataset container."""
+from __future__ import print_function
+__all__ = ['AudioFolderDataset']
+
+import os
+import warnings
+from itertools import islice
+import csv
+from mxnet.gluon.data import Dataset
+from mxnet import ndarray as nd
+try:
+    import librosa
+except ImportError:
+    raise ImportError("librosa could not be imported, so audio cannot be loaded "
+                      "into numpy arrays. Run `pip install librosa` and retry.")
+
+
+
+class AudioFolderDataset(Dataset):
+    """A dataset for loading audio files stored in a folder structure like::
+
+        root/children_playing/0.wav
+        root/siren/23.wav
+        root/drilling/26.wav
+        root/dog_barking/42.wav
+    OR
+    wav files and a csv file that maps each file name to its label
+
+    Parameters
+    ----------
+    root : str
+        Path to root directory.
+    train_csv: str, default None
+        The filename of the training csv, if the labels come from a csv
+    file_format: str, default '.wav'
+        The format of the audio files (.wav)
+    skip_header: boolean, default False
+        Whether to skip the first row of the csv file (the header) while reading
+
+
+    Attributes
+    ----------
+    synsets : list
+        List of class names. `synsets[i]` is the name for the `i`th label
+    items : list of tuples
+        List of all audio in (filename, label) pairs.
+
+    """
+    def __init__(self, root, train_csv=None, file_format='.wav', skip_header=False):
+        self._root = os.path.expanduser(root)
+        self._exts = ['.wav']
+        self._format = file_format
+        self._train_csv = train_csv
+        if file_format.lower() not in self._exts:
+            raise RuntimeError("Format {} not supported currently.".format(file_format))
+        skip_rows = 1 if skip_header else 0
+        self._list_audio_files(self._root, skip_rows=skip_rows)
+
+
+    def _list_audio_files(self, root, skip_rows=0):
+        """Populates synsets - a map of index to label for the data items.
+        Populates the data in the dataset, making tuples of (data, label)
+        """
+        self.synsets = []
+        self.items = []
+        if not self._train_csv:
+            # The audio files are organized in a folder structure with
+            # the directory name as the label and the audio files inside
+            self._folder_structure(root)
+        else:
+            # train_csv contains the mapping between filename and label
+            self._csv_labelled_dataset(root, skip_rows=skip_rows)
+
+        # Generate the synset.txt file
+        if not os.path.exists("./synset.txt"):
+            with open("./synset.txt", "w") as synsets_file:
+                for item in self.synsets:
+                    synsets_file.write(item + os.linesep)
+            print("Synsets written to synset.txt")
+        else:
+            warnings.warn("Synset file already exists in the current directory! Not generating synset.txt.")
+
+
+    def _folder_structure(self, root):
+        for folder in sorted(os.listdir(root)):
+            path = os.path.join(root, folder)
+            if not os.path.isdir(path):
+                warnings.warn('Ignoring {}, which is not a directory.'.format(path))
+                continue
+            label = len(self.synsets)
+            self.synsets.append(folder)
+            for filename in sorted(os.listdir(path)):
+                file_name = os.path.join(path, filename)
+                ext = os.path.splitext(file_name)[1]
+                if ext.lower() not in self._exts:
+                    warnings.warn('Ignoring {} of type {}. Only support {}'
+                                  .format(filename, ext, ', '.join(self._exts)))
+                    continue
+                self.items.append((file_name, label))
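+
+    # The csv referenced by `train_csv` is assumed to map file names to labels,
+    # one pair per row, e.g. (hypothetical rows; the header row is consumed
+    # when skip_header=True):
+    #
+    #     ID,Class
+    #     0.wav,siren
+    #     1.wav,street_music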
+
+    def _csv_labelled_dataset(self, root, skip_rows=0):
+        with open(self._train_csv, "r") as traincsv:
+            for line in islice(csv.reader(traincsv), skip_rows, None):
+                filename = os.path.join(root, line[0])
+                label = line[1].strip()
+                if label not in self.synsets:
+                    self.synsets.append(label)
+                if not filename.endswith(self._format):
+                    filename = filename + self._format
+                self.items.append((filename, nd.array([self.synsets.index(label)]).reshape((1,))))
+
+
+    def __getitem__(self, idx):
+        """Retrieve the item (data, label) stored at idx in items"""
+        filename, label = self.items[idx]
+        # res_type 'kaiser_fast' is passed for better load performance
+        X1, _ = librosa.load(filename, res_type='kaiser_fast')
+        return nd.array(X1), label
+
+
+    def __len__(self):
+        """Retrieves the number of items in the dataset"""
+        return len(self.items)
+
+
+    def transform_first(self, fn, lazy=False):
+        """Returns a new dataset with the first element of each sample
+        transformed by the transformer function `fn`.
+
+        This is useful, for example, when you only want to transform data
+        while keeping the label as is. `lazy=False` is the default here so
+        that all transforms are performed in one shot rather than during
+        training; this is a performance consideration.
+
+        Parameters
+        ----------
+        fn : callable
+            A transformer function that takes the first element of a sample
+            as input and returns the transformed element.
+        lazy : bool, default False
+            If False, transforms all samples at once. Otherwise,
+            transforms each sample on demand. Note that if `fn`
+            is stochastic, you must set lazy to True or you will
+            get the same result on all epochs.
+
+        Returns
+        -------
+        Dataset
+            The transformed dataset.
+
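+        Examples
+        --------
+        A minimal sketch (assumes `./Train` and `./train.csv` exist, and that
+        the MFCC transform from this example is importable):
+
+        >>> from transforms import MFCC
+        >>> dataset = AudioFolderDataset('./Train', train_csv='./train.csv',
+        ...                              skip_header=True)
+        >>> dataset = dataset.transform_first(MFCC(), lazy=False)
+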
+        """
+        return super(AudioFolderDataset, self).transform_first(fn, lazy=lazy)
diff --git a/example/gluon/audio/urban_sounds/model.py b/example/gluon/audio/urban_sounds/model.py
new file mode 100644
index 000000000000..af23cb946e2e
--- /dev/null
+++ b/example/gluon/audio/urban_sounds/model.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module builds an MLP model with a configurable output layer (number of
+units in the last layer). Users can pass any number of units for the last layer.
+Since this dataset has 10 labels, the default value is num_labels = 10.
+"""
+import mxnet as mx
+from mxnet import gluon
+
+# Defining a neural network with the given number of output labels
+def get_net(num_labels=10):
+    net = gluon.nn.Sequential()
+    with net.name_scope():
+        net.add(gluon.nn.Dense(256, activation="relu"))  # 1st hidden layer (256 nodes)
+        net.add(gluon.nn.Dense(256, activation="relu"))  # 2nd hidden layer (256 nodes)
+        net.add(gluon.nn.Dense(num_labels))
+    net.collect_params().initialize(mx.init.Xavier())
+    return net
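+
+# A minimal smoke-test sketch (the input shape is an assumption: the MFCC
+# transform in this example yields 20 features per clip, so a (1, 20) batch
+# is a plausible input):
+#
+#     >>> net = get_net(num_labels=10)
+#     >>> net(mx.nd.random.uniform(shape=(1, 20))).shape
+#     (1, 10)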
diff --git a/example/gluon/audio/urban_sounds/predict.py b/example/gluon/audio/urban_sounds/predict.py
new file mode 100644
index 000000000000..0c3631173667
--- /dev/null
+++ b/example/gluon/audio/urban_sounds/predict.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Prediction module for Urban Sounds Classification"""
+from __future__ import print_function
+import os
+import sys
+import warnings
+import mxnet as mx
+from mxnet import nd
+from model import get_net
+try:
+    import librosa
+except ImportError:
+    raise ImportError("Librosa is not installed! Please run the following command: "
+                      "`pip install librosa`")
+sys.path.append('../')
+
+def predict(prediction_dir='./Test'):
+    """Run predictions on the audio files in `prediction_dir`.
+
+    Parameters
+    ----------
+    prediction_dir: string, default ./Test
+        The directory that contains the audio files on which predictions are to be made
+
+    """
+
+    if not os.path.exists(prediction_dir):
+        warnings.warn("The directory on which predictions are to be made is not found!")
+        return
+
+    if len(os.listdir(prediction_dir)) == 0:
+        warnings.warn("The directory on which predictions are to be made is empty! Exiting...")
+        return
+
+    # Loading synsets
+    if not os.path.exists('./synset.txt'):
+        warnings.warn("The synset or labels for the dataset do not exist. Please run the training script first.")
+        return
+
+    with open("./synset.txt", "r") as f:
+        synset = [l.rstrip() for l in f]
+    net = get_net(len(synset))
+    print("Trying to load the model with the saved parameters...")
+    if not os.path.exists("./net.params"):
+        warnings.warn("The model does not have any saved parameters... Cannot proceed! Train the model first.")
+        return
+
+    net.load_parameters("./net.params")
+    file_names = os.listdir(prediction_dir)
+    full_file_names = [os.path.join(prediction_dir, item) for item in file_names]
+    from transforms import MFCC
+    mfcc = MFCC()
+    print("\nStarting predictions for audio files in ", prediction_dir, " ....\n")
+    for filename in full_file_names:
+        # res_type 'kaiser_fast' loads faster than 'kaiser_best'; passed to reduce load time
+        X1, _ = librosa.load(filename, res_type='kaiser_fast')
+        transformed_test_data = mfcc(mx.nd.array(X1))
+        output = net(transformed_test_data.reshape((1, -1)))
+        prediction = nd.argmax(output, axis=1)
+        print(filename, " -> ", synset[int(prediction.asscalar())])
+
+
+if __name__ == '__main__':
+    try:
+        import argparse
+        parser = argparse.ArgumentParser(description="Urban Sounds classification example - MXNet")
+        parser.add_argument('--pred', '-p', help="Enter the folder path that contains your audio files", type=str)
+        args = parser.parse_args()
+        pred_dir = args.pred if args.pred else './Test'
+
+    except ImportError:
+        warnings.warn("argparse module not available! Passing default arguments.")
+        pred_dir = './Test'
+    predict(prediction_dir=pred_dir)
+    print("Urban sounds classification Prediction DONE!")
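+
+# The module can also be used programmatically (a sketch; assumes net.params
+# and synset.txt produced by train.py are in the current directory):
+#
+#     >>> from predict import predict
+#     >>> predict(prediction_dir='./Test')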
+"""The module to run training on the Urban sounds dataset""" +from __future__ import print_function +import sys +import os +import time +import warnings +import mxnet as mx +from mxnet import gluon, nd, autograd +from datasets import AudioFolderDataset +import model +sys.path.append('../') + +def evaluate_accuracy(data_iterator, net): + """Function to evaluate accuracy of any data iterator passed to it as an argument""" + acc = mx.metric.Accuracy() + for data, label in data_iterator: + output = net(data) + predictions = nd.argmax(output, axis=1) + predictions = predictions.reshape((-1, 1)) + acc.update(preds=predictions, labels=label) + return acc.get()[1] + + +def train(train_dir=None, train_csv=None, epochs=30, batch_size=32): + """Function responsible for running the training the model.""" + + if not train_dir or not os.path.exists(train_dir) or not train_csv: + warnings.warn("No train directory could be found ") + return + # Make a dataset from the local folder containing Audio data + print("\nMaking an Audio Dataset...\n") + tick = time.time() + aud_dataset = AudioFolderDataset(train_dir, train_csv=train_csv, file_format='.wav', skip_header=True) + tock = time.time() + + print("Loading the dataset took ", (tock-tick), " seconds.") + print("\n=======================================\n") + print("Number of output classes = ", len(aud_dataset.synsets)) + print("\nThe labels are : \n") + print(aud_dataset.synsets) + # Get the model to train + net = model.get_net(len(aud_dataset.synsets)) + print("\nNeural Network = \n") + print(net) + print("\nModel - Neural Network Generated!\n") + print("=======================================\n") + + #Define the loss - Softmax CE Loss + softmax_loss = gluon.loss.SoftmaxCELoss(from_logits=False, sparse_label=True) + print("Loss function initialized!\n") + print("=======================================\n") + + #Define the trainer with the optimizer + trainer = gluon.Trainer(net.collect_params(), 'adadelta') + print("Optimizer - Trainer function initialized!\n") + print("=======================================\n") + print("Loading the dataset to the Gluon's OOTB Dataloader...") + + #Getting the data loader out of the AudioDataset and passing the transform + from transforms import MFCC + aud_transform = MFCC() + tick = time.time() + + audio_train_loader = gluon.data.DataLoader(aud_dataset.transform_first(aud_transform), batch_size=32, shuffle=True) + tock = time.time() + print("Time taken to load data and apply transform here is ", (tock-tick), " seconds.") + print("=======================================\n") + + + print("Starting the training....\n") + # Training loop + tick = time.time() + batch_size = batch_size + num_examples = len(aud_dataset) + + for epoch in range(epochs): + cumulative_loss = 0 + for data, label in audio_train_loader: + with autograd.record(): + output = net(data) + loss = softmax_loss(output, label) + loss.backward() + + trainer.step(batch_size) + cumulative_loss += mx.nd.sum(loss).asscalar() + + if epoch%5 == 0: + train_accuracy = evaluate_accuracy(audio_train_loader, net) + print("Epoch {}. 
+
+    print("Starting the training....\n")
+    # Training loop
+    tick = time.time()
+    num_examples = len(aud_dataset)
+
+    for epoch in range(epochs):
+        cumulative_loss = 0
+        for data, label in audio_train_loader:
+            with autograd.record():
+                output = net(data)
+                loss = softmax_loss(output, label)
+            loss.backward()
+
+            trainer.step(batch_size)
+            cumulative_loss += mx.nd.sum(loss).asscalar()
+
+        if epoch % 5 == 0:
+            train_accuracy = evaluate_accuracy(audio_train_loader, net)
+            print("Epoch {}. Loss: {} Train accuracy : {} ".format(epoch, cumulative_loss / num_examples, train_accuracy))
+            print("\n------------------------------\n")
+
+    train_accuracy = evaluate_accuracy(audio_train_loader, net)
+    tock = time.time()
+    print("\nFinal training accuracy: ", train_accuracy)
+
+    print("Training the sound classification MLP model for ", epochs, " epochs took ", (tock - tick), " seconds")
+    print("====================== END ======================\n")
+
+    print("Trying to save the model parameters here...")
+    net.save_parameters("./net.params")
+    print("Saved the model parameters in the current directory.")
+
+
+if __name__ == '__main__':
+    training_dir = './Train'
+    training_csv = './train.csv'
+    epochs = 30
+    batch_size = 32
+
+    try:
+        import argparse
+        parser = argparse.ArgumentParser(description="Urban Sounds classification example - MXNet Gluon")
+        parser.add_argument('--train', '-t', help="Enter the folder path that contains your audio files", type=str)
+        parser.add_argument('--csv', '-c', help="Enter the filename of the csv that contains filename"
+                                                " to label mapping", type=str)
+        parser.add_argument('--epochs', '-e', help="Enter the number of epochs"
+                                                   " you would want to run the training for.", type=int)
+        parser.add_argument('--batch_size', '-b', help="Enter the batch_size of data", type=int)
+        args = parser.parse_args()
+
+        if args.train:
+            training_dir = args.train
+
+        if args.csv:
+            training_csv = args.csv
+
+        if args.epochs:
+            epochs = args.epochs
+
+        if args.batch_size:
+            batch_size = args.batch_size
+
+
+    except ImportError:
+        warnings.warn("Argument parsing module could not be imported. "
+                      "Passing default arguments.")
+
+
+    train(train_dir=training_dir, train_csv=training_csv, epochs=epochs, batch_size=batch_size)
+    print("Urban sounds classification Training DONE!")