-
Notifications
You must be signed in to change notification settings - Fork 0
/
wav_dataset.py
102 lines (78 loc) · 3.23 KB
/
wav_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from mxnet.gluon.data import dataset
from mxnet import nd
from scipy.io.wavfile import read
from dataset.preprocessing.text import text_to_sequence
from model.layers.tacotron_stft import TacotronSTFT
from params import tacotron_params
import os
import glob
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
class WavDataset(dataset.Dataset):
    """Dataset of (mel spectrogram, encoded transcript) pairs read from a folder.

    Every ``<name>.wav`` file found directly under ``root`` that has a sibling
    ``<name>.txt`` file becomes one sample. Transcripts are loaded and encoded
    once, at construction time; audio is read and turned into a mel
    spectrogram lazily on every ``__getitem__`` call.

    Parameters
    ----------
    root : str
        Directory holding the ``.wav`` / ``.txt`` pairs; ``~`` is expanded.
    text_transforms : callable or None
        Called as ``text_transforms(text, text_cleaners)`` to encode a
        transcript. When ``None`` the joined text is handed to ``nd.array``
        unchanged.
    text_cleaners : list of str, optional
        Cleaner names forwarded to ``text_transforms``. Defaults to
        ``['english_cleaners']``.
    *args
        Accepted and ignored.
    **kwargs
        Must contain ``max_wav_value`` (the divisor applied to raw samples);
        every remaining entry is forwarded to ``TacotronSTFT``.

    Raises
    ------
    KeyError
        If ``max_wav_value`` is missing from ``kwargs``.
    """

    def __init__(self, root, text_transforms, text_cleaners=None, *args, **kwargs):
        super(WavDataset, self).__init__()
        self._root = os.path.expanduser(root)
        self._exts = ['.wav', '.txt']
        # max_wav_value is consumed here; TacotronSTFT only gets what is left.
        self._max_wav_value = kwargs.pop('max_wav_value')
        self._stft = TacotronSTFT(**kwargs)
        # A fresh list per instance avoids the shared-mutable-default pitfall.
        self._text_cleaners = ['english_cleaners'] if text_cleaners is None else text_cleaners
        self._text_transforms = text_transforms
        self._items = []
        self._list_records()

    def _list_records(self):
        """Collect every wav file of the root folder that has a transcript."""
        # Escape the root so characters such as '[' in a directory name are
        # taken literally instead of being interpreted as glob syntax.
        pattern = os.path.join(glob.escape(self._root), '*{}'.format(self._exts[0]))
        # glob returns files in arbitrary order; sorting keeps indices stable.
        wav_files = sorted(glob.glob(pattern))
        # Counts wav files, which may exceed the number of usable pairs.
        logging.info('{} sample(s) found in dataset'.format(len(wav_files)))
        for wav_file in wav_files:
            text_file = '{}{}'.format(os.path.splitext(wav_file)[0], self._exts[1])
            # Wav files without a matching transcript are skipped silently.
            if os.path.isfile(text_file):
                self._items.append((wav_file, self._load_text(text_file)))

    def _load_wav(self, file_path):
        """Read a wav file and return its samples as a float32 NDArray."""
        with warnings.catch_warnings():
            # Silence any warning raised while scipy parses the file.
            warnings.simplefilter('ignore')
            _, data = read(file_path)
        return nd.array(data.astype(np.float32))

    def _load_text(self, file_path):
        """Read a transcript, join its lines with spaces and encode it."""
        with open(file_path, encoding='utf-8') as f:
            text = ' '.join(line.strip() for line in f)
        if self._text_transforms is not None:
            text = self._text_transforms(text, self._text_cleaners)
        # np.int was only an alias of the builtin int and no longer exists in
        # recent NumPy releases; the builtin is the exact equivalent.
        return nd.array(text, dtype=int)

    def __getitem__(self, idx):
        """Return the ``(mel spectrogram, encoded text)`` pair at ``idx``."""
        wav_file, encoded_text = self._items[idx]
        wav = self._load_wav(wav_file)
        audio_normalized = wav / self._max_wav_value
        # Leading axis added before the STFT call (presumably the batch axis
        # TacotronSTFT expects - confirm) and squeezed away afterwards.
        audio_normalized = audio_normalized.expand_dims(0)
        melspec = self._stft.mel_spectrogram(audio_normalized).squeeze()
        return melspec, encoded_text

    def __len__(self):
        """Return the number of (wav, transcript) pairs in the dataset."""
        return len(self._items)
if __name__ == '__main__':
    # Manual smoke test: build the dataset from a local folder and display the
    # mel spectrogram of every sample, one figure at a time.
    logging.basicConfig()
    logging.getLogger().setLevel(logging.INFO)
    params = {
        'max_wav_value': 32768.0,  # for 16 bits files
        'sampling_rate': 22050,
        'filter_length': 1024,
        'hop_length': 256,
        'win_length': 1024,
        'n_mel_channels': 80,
        'mel_fmin': 0.0,
        'mel_fmax': 8000.0
    }
    french = WavDataset('~/datasets/tacotron', text_to_sequence, **params)
    assert isinstance(french[0], tuple)
    for i, (data, _) in enumerate(french):
        print('Plotting sample {}'.format(i))
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 4))
        ax.set_title('Mel spectrogram')
        ax.set_xlabel('Frames')
        ax.set_ylabel('Mel channels')
        cax = ax.matshow(data.asnumpy(), interpolation='nearest', aspect='auto', cmap='viridis', origin='lower')
        fig.colorbar(cax)
        plt.show()
        # Release the figure so they do not pile up over a long dataset.
        plt.close(fig)