#!/usr/bin/python2
# -*- coding: utf-8 -*-
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/tacotron
'''
from __future__ import print_function

import codecs
import os
import re
import unicodedata

import numpy as np
import tensorflow as tf

from hyperparams import Hyperparams as hp
from utils import *  # load_spectrograms (used in get_batch) comes from here


def load_vocab():
    """Builds char <-> index lookup tables from hp.vocab."""
    char2idx = {char: idx for idx, char in enumerate(hp.vocab)}
    idx2char = {idx: char for idx, char in enumerate(hp.vocab)}
    return char2idx, idx2char
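
# Illustrative sanity check of the vocab tables; a sketch only, assuming
# hp.vocab is the usual character string from the upstream repo (roughly
# "PE abcdefghijklmnopqrstuvwxyz'.?", where P is padding and E is EOS):
#
#   char2idx, idx2char = load_vocab()
#   assert all(idx2char[char2idx[c]] == c for c in hp.vocab)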


def text_normalize(text):
    text = ''.join(char for char in unicodedata.normalize('NFD', text)
                   if unicodedata.category(char) != 'Mn')  # Strip accents
    text = text.lower()
    text = re.sub("[^{}]".format(hp.vocab), " ", text)
    text = re.sub("[ ]+", " ", text)
    return text
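
# Worked example of text_normalize; a sketch assuming a lowercase-ASCII
# hp.vocab: NFD decomposition drops the combining accent marks, the text is
# lowercased, and any character outside hp.vocab becomes a space before runs
# of spaces are collapsed.
#
#   text_normalize(u"Déjà vu")  ->  "deja vu"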


def load_data(mode="train"):
    # Load vocabulary
    char2idx, idx2char = load_vocab()

    if mode in ("train", "eval"):
        # Parse the transcript file.
        fpaths, text_lengths, texts = [], [], []
        transcript = os.path.join(hp.data, 'transcript.csv')
        lines = codecs.open(transcript, 'r', 'utf-8').readlines()
        total_hours = 0
        if mode == "train":
            lines = lines[1:]
        else:  # eval: we attack only a single sample
            lines = lines[:1]

        for line in lines:
            fname, _, text = line.strip().split("|")

            fpath = os.path.join(hp.data, "wavs", fname + ".wav")
            fpaths.append(fpath)

            text = text_normalize(text) + "E"  # E: EOS
            text = [char2idx[char] for char in text]
            text_lengths.append(len(text))
            texts.append(np.array(text, np.int32).tostring())

        return fpaths, text_lengths, texts
    else:  # synthesis
        # Parse the test sentences.
        lines = codecs.open(hp.test_data, 'r', 'utf-8').readlines()[1:]
        sents = [text_normalize(line.split(" ", 1)[-1]).strip() + "E" for line in lines]  # E: EOS
        lengths = [len(sent) for sent in sents]
        maxlen = sorted(lengths, reverse=True)[0]
        texts = np.zeros((len(sents), maxlen), np.int32)
        for i, sent in enumerate(sents):
            texts[i, :len(sent)] = [char2idx[char] for char in sent]
        return texts
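
# Usage sketch (illustrative only; the "synthesize" mode string is arbitrary,
# since any mode other than "train"/"eval" takes the hp.test_data branch):
#
#   fpaths, text_lengths, texts = load_data(mode="train")  # parallel lists
#   test_texts = load_data(mode="synthesize")              # padded int32 matrix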


def get_batch():
    """Loads training data and puts it in queues."""
    with tf.device('/cpu:0'):
        # Load data
        fpaths, text_lengths, texts = load_data()  # lists
        maxlen, minlen = max(text_lengths), min(text_lengths)

        # Calc total batch count
        num_batch = len(fpaths) // hp.batch_size

        fpaths = tf.convert_to_tensor(fpaths)
        text_lengths = tf.convert_to_tensor(text_lengths)
        texts = tf.convert_to_tensor(texts)

        # Create Queues
        fpath, text_length, text = tf.train.slice_input_producer([fpaths, text_lengths, texts], shuffle=True)

        # Parse
        text = tf.decode_raw(text, tf.int32)  # (None,)

        if hp.prepro:
            def _load_spectrograms(fpath):
                fname = os.path.basename(fpath)
                mel = "mels/{}".format(fname.replace("wav", "npy"))
                mag = "mags/{}".format(fname.replace("wav", "npy"))
                return fname, np.load(mel), np.load(mag)

            fname, mel, mag = tf.py_func(_load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32])
        else:
            fname, mel, mag = tf.py_func(load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32])  # (None, n_mels)

        # Add shape information
        fname.set_shape(())
        text.set_shape((None,))
        mel.set_shape((None, hp.n_mels * hp.r))
        mag.set_shape((None, hp.n_fft // 2 + 1))

        # Batching
        _, (texts, mels, mags, fnames) = tf.contrib.training.bucket_by_sequence_length(
            input_length=text_length,
            tensors=[text, mel, mag, fname],
            batch_size=hp.batch_size,
            bucket_boundaries=[i for i in range(minlen + 1, maxlen - 1, 20)],
            num_threads=16,
            capacity=hp.batch_size * 4,
            dynamic_pad=True)

    return texts, mels, mags, fnames, num_batch
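

if __name__ == '__main__':
    # Minimal smoke test of the input pipeline; an illustrative sketch that
    # assumes a TF 1.x runtime and that hp.data points at a prepared corpus.
    # tf.train.Supervisor starts the queue runners created by
    # slice_input_producer and bucket_by_sequence_length.
    texts, mels, mags, fnames, num_batch = get_batch()
    sv = tf.train.Supervisor()
    with sv.managed_session() as sess:
        texts_np, mels_np = sess.run([texts, mels])
        print("texts:", texts_np.shape, "mels:", mels_np.shape,
              "batches/epoch:", num_batch)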