Sorry, I cannot reprocess the dataset; using THCHS I cannot solve this problem. #315

Open

wants to merge 41 commits into base: master

41 commits
5d881fa
For mandarin
begeekmyfriend Jan 29, 2018
5aa70e8
Update symbolic characters
begeekmyfriend Jan 30, 2018
3c2556d
THCHS30
begeekmyfriend Feb 6, 2018
59c1800
Update
begeekmyfriend Feb 18, 2018
276dacc
Adjust min silence time span for end point
begeekmyfriend Feb 19, 2018
05de2ec
Expand data set
begeekmyfriend Mar 1, 2018
7c00023
Adjust fft sampling points and frame length according to audio SR
begeekmyfriend Mar 28, 2018
cbc2b87
Change mel fbank back 80 for evaluation
begeekmyfriend Apr 9, 2018
cdcfac1
Update
begeekmyfriend Apr 10, 2018
dead31e
Merge remote-tracking branch 'upstream/master'
begeekmyfriend Aug 17, 2018
b02eee7
Replace attention wrapper with location sensitive attention
begeekmyfriend Aug 17, 2018
f8de0d7
Update
begeekmyfriend Aug 17, 2018
a614f95
Use 0-1 normalization
begeekmyfriend Aug 19, 2018
8a8696a
Adjust initial learning rate
begeekmyfriend Aug 19, 2018
4d1c14a
Fix synthesis bug
begeekmyfriend Aug 24, 2018
5fc4708
Add stop token target for learning when to stop decoding
begeekmyfriend Aug 27, 2018
41b385b
Revert to inv_spectrogram_tensorflow for speed advantage
begeekmyfriend Aug 28, 2018
da7abd0
Replace AttentionWrapper with location sensitive attention
begeekmyfriend Aug 28, 2018
78bec62
librosa 0.6+ only supports floating mode
begeekmyfriend Aug 28, 2018
2e7abe3
Automatically restore from the last checkpoint
begeekmyfriend Aug 28, 2018
3ca25ef
Add stop token target to learn when to stop decoding
begeekmyfriend Aug 28, 2018
1f8d32f
Apply regularization L2 with weight 1e-6 according to tacotron 2
begeekmyfriend Aug 28, 2018
9340d42
Fix dropout training parameter bug
begeekmyfriend Aug 30, 2018
e7e1ee5
Fix dropout bug
begeekmyfriend Sep 2, 2018
cec3ac3
Remove find endpoint method
begeekmyfriend Sep 26, 2018
b0461d8
Cancel clipping for normalization
begeekmyfriend Sep 29, 2018
b0a26f6
Fix bug
begeekmyfriend Nov 28, 2018
59d6cb3
Add teacher forcing ratio for overcoming overfitting
begeekmyfriend Nov 30, 2018
2dab01c
Add L2 regularization
begeekmyfriend Nov 30, 2018
3765d6c
Merge branch 'mandarin' of https://github.com/begeekmyfriend/tacotron…
begeekmyfriend Nov 30, 2018
bd2d80e
Symmetric mels helps quick alignment
begeekmyfriend Dec 1, 2018
b055e85
Update hyper parameter for stronger fitting
begeekmyfriend Dec 3, 2018
86a89a2
Add dropout in convolution for regularization
begeekmyfriend Dec 4, 2018
ca442ca
Add wav rescaling for unified measure
begeekmyfriend Dec 5, 2018
2a88f48
Rescaling for unified measure for all clips
begeekmyfriend Dec 5, 2018
23c5df2
Add FIR as bandpass filter for less noises
begeekmyfriend Dec 6, 2018
0fbbb8c
Fix bug
begeekmyfriend Dec 10, 2018
c4632e9
Merge branch 'mandarin'
begeekmyfriend Dec 10, 2018
cabc77b
Add bandwidth limitation for corpus
begeekmyfriend Dec 13, 2018
e093cf0
Remove dropout in conv1d to reduce loss
begeekmyfriend Apr 15, 2019
60d6932
Loading latest checkpoint on synthesis
begeekmyfriend Apr 16, 2019
26 changes: 20 additions & 6 deletions datasets/datafeeder.py
@@ -12,6 +12,7 @@
 _batches_per_group = 32
 _p_cmudict = 0.5
 _pad = 0
+_stop_token_pad = 1


 class DataFeeder(threading.Thread):
@@ -37,17 +38,19 @@ def __init__(self, coordinator, metadata_filename, hparams):
       tf.placeholder(tf.int32, [None, None], 'inputs'),
       tf.placeholder(tf.int32, [None], 'input_lengths'),
       tf.placeholder(tf.float32, [None, None, hparams.num_mels], 'mel_targets'),
-      tf.placeholder(tf.float32, [None, None, hparams.num_freq], 'linear_targets')
+      tf.placeholder(tf.float32, [None, None, hparams.num_freq], 'linear_targets'),
+      tf.placeholder(tf.float32, [None, None], 'stop_token_targets')
     ]

     # Create queue for buffering data:
-    queue = tf.FIFOQueue(8, [tf.int32, tf.int32, tf.float32, tf.float32], name='input_queue')
+    queue = tf.FIFOQueue(8, [tf.int32, tf.int32, tf.float32, tf.float32, tf.float32], name='input_queue')
     self._enqueue_op = queue.enqueue(self._placeholders)
-    self.inputs, self.input_lengths, self.mel_targets, self.linear_targets = queue.dequeue()
+    self.inputs, self.input_lengths, self.mel_targets, self.linear_targets, self.stop_token_targets = queue.dequeue()
     self.inputs.set_shape(self._placeholders[0].shape)
     self.input_lengths.set_shape(self._placeholders[1].shape)
     self.mel_targets.set_shape(self._placeholders[2].shape)
     self.linear_targets.set_shape(self._placeholders[3].shape)
+    self.stop_token_targets.set_shape(self._placeholders[4].shape)

     # Load CMUDict: If enabled, this will randomly substitute some words in the training data with
     # their ARPABet equivalents, which will allow you to also pass ARPABet to the model for
@@ -97,7 +100,7 @@ def _enqueue_next_group(self):


   def _get_next_example(self):
-    '''Loads a single example (input, mel_target, linear_target, cost) from disk'''
+    '''Loads a single example (input, mel_target, linear_target, stop_token_target) from disk'''
     if self._offset >= len(self._metadata):
       self._offset = 0
       random.shuffle(self._metadata)
@@ -111,7 +114,8 @@ def _get_next_example(self):
     input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
     linear_target = np.load(os.path.join(self._datadir, meta[0]))
     mel_target = np.load(os.path.join(self._datadir, meta[1]))
-    return (input_data, mel_target, linear_target, len(linear_target))
+    stop_token_target = np.asarray([0.] * len(mel_target))
+    return (input_data, mel_target, linear_target, stop_token_target, len(linear_target))


   def _maybe_get_arpabet(self, word):
@@ -125,7 +129,8 @@ def _prepare_batch(batch, outputs_per_step):
   input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32)
   mel_targets = _prepare_targets([x[1] for x in batch], outputs_per_step)
   linear_targets = _prepare_targets([x[2] for x in batch], outputs_per_step)
-  return (inputs, input_lengths, mel_targets, linear_targets)
+  stop_token_targets = _prepare_stop_token_targets([x[3] for x in batch], outputs_per_step)
+  return (inputs, input_lengths, mel_targets, linear_targets, stop_token_targets)


 def _prepare_inputs(inputs):
@@ -138,6 +143,11 @@ def _prepare_targets(targets, alignment):
   return np.stack([_pad_target(t, _round_up(max_len, alignment)) for t in targets])


+def _prepare_stop_token_targets(targets, alignment):
+  max_len = max((len(t) for t in targets)) + 1
+  return np.stack([_pad_stop_token_target(t, _round_up(max_len, alignment)) for t in targets])
+
+
 def _pad_input(x, length):
   return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)

@@ -146,6 +156,10 @@ def _pad_target(t, length):
   return np.pad(t, [(0, length - t.shape[0]), (0,0)], mode='constant', constant_values=_pad)


+def _pad_stop_token_target(t, length):
+  return np.pad(t, (0, length - t.shape[0]), mode='constant', constant_values=_stop_token_pad)
+
+
 def _round_up(x, multiple):
   remainder = x % multiple
   return x if remainder == 0 else x + multiple - remainder
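
The stop-token targets added above are all-zero vectors with one entry per mel frame, padded with _stop_token_pad = 1; the extra `+ 1` in _prepare_stop_token_targets guarantees every clip ends with at least one positive frame, so the decoder always sees a "stop" example. A minimal, self-contained sketch of what one padded batch looks like (the batch contents and outputs_per_step value are illustrative, not taken from the PR):

import numpy as np

_stop_token_pad = 1

def _round_up(x, multiple):
    remainder = x % multiple
    return x if remainder == 0 else x + multiple - remainder

# Hypothetical mini-batch: two "utterances" of 3 and 5 mel frames.
stop_targets = [np.zeros(3), np.zeros(5)]
max_len = max(len(t) for t in stop_targets) + 1   # +1 guarantees at least one stop frame
padded_len = _round_up(max_len, 2)                # assuming outputs_per_step = 2
batch = np.stack([
    np.pad(t, (0, padded_len - t.shape[0]), mode='constant', constant_values=_stop_token_pad)
    for t in stop_targets
])
print(batch)
# [[0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 0. 0. 1.]]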
17 changes: 10 additions & 7 deletions eval.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 import re
+import tensorflow as tf
 from hparams import hparams, hparams_debug_string
 from synthesizer import Synthesizer

@@ -19,34 +20,36 @@
 ]


-def get_output_base_path(checkpoint_path):
+def get_output_base_path(ckpt_path):
   base_dir = os.path.dirname(checkpoint_path)
   m = re.compile(r'.*?\.ckpt\-([0-9]+)').match(checkpoint_path)
   name = 'eval-%d' % int(m.group(1)) if m else 'eval'
   return os.path.join(base_dir, name)


-def run_eval(args):
+def run_eval(ckpt_dir):
   print(hparams_debug_string())
+  checkpoint = tf.train.get_checkpoint_state(ckpt_dir).model_checkpoint_path
   synth = Synthesizer()
-  synth.load(args.checkpoint)
-  base_path = get_output_base_path(args.checkpoint)
+  synth.load(checkpoint)
+  base_path = get_output_base_path(checkpoint)
   for i, text in enumerate(sentences):
-    path = '%s-%d.wav' % (base_path, i)
+    path = '%s-%03d.wav' % (base_path, i)
     print('Synthesizing: %s' % path)
     with open(path, 'wb') as f:
       f.write(synth.synthesize(text))


 def main():
   parser = argparse.ArgumentParser()
-  parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
+  parser.add_argument('--checkpoint', default='logs-tacotron', help='Path to model checkpoint')
   parser.add_argument('--hparams', default='',
     help='Hyperparameter overrides as a comma-separated list of name=value pairs')
   args = parser.parse_args()
   os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+  os.environ['CUDA_VISIBLE_DEVICES'] = '1'
   hparams.parse(args.hparams)
-  run_eval(args)
+  run_eval(args.checkpoint)


 if __name__ == '__main__':
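
With this change, --checkpoint takes the training log directory rather than a specific checkpoint file, and tf.train.get_checkpoint_state() resolves the newest checkpoint recorded there, which is what makes the "Loading latest checkpoint on synthesis" commit work. A minimal sketch of the resolution step, using the diff's default directory name (any directory containing a TensorFlow checkpoint state file behaves the same):

import tensorflow as tf

# Resolve the newest checkpoint recorded in a training log directory.
ckpt_state = tf.train.get_checkpoint_state('logs-tacotron')
if ckpt_state is not None:
    print('Latest checkpoint:', ckpt_state.model_checkpoint_path)
else:
    # run_eval() above assumes the directory exists; guard it when reusing this.
    print('No checkpoint found in logs-tacotron')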
25 changes: 15 additions & 10 deletions hparams.py
@@ -8,36 +8,41 @@
   cleaners='english_cleaners',

   # Audio:
-  num_mels=80,
+  num_mels=160,
   num_freq=1025,
-  sample_rate=20000,
+  sample_rate=24000,
   frame_length_ms=50,
   frame_shift_ms=12.5,
   preemphasis=0.97,
   min_level_db=-100,
   ref_level_db=20,
+  max_frame_num=1000,
+  max_abs_value = 4,
+  fmin = 125, # for male, set 55
+  fmax = 7600, # for male, set 3600

   # Model:
   outputs_per_step=5,
-  embed_depth=256,
-  prenet_depths=[256, 128],
+  embed_depth=512,
+  prenet_depths=[256, 256],
   encoder_depth=256,
-  postnet_depth=256,
-  attention_depth=256,
-  decoder_depth=256,
+  postnet_depth=512,
+  attention_depth=128,
+  decoder_depth=1024,

   # Training:
   batch_size=32,
   adam_beta1=0.9,
   adam_beta2=0.999,
-  initial_learning_rate=0.002,
+  reg_weight = 1e-6,
+  initial_learning_rate=0.001,
   decay_learning_rate=True,
   use_cmudict=False, # Use CMUDict during training to learn pronunciation of ARPAbet phonemes

   # Eval:
-  max_iters=200,
+  max_iters=300,
   griffin_lim_iters=60,
-  power=1.5, # Power to raise magnitudes to prior to Griffin-Lim
+  power=1.2, # Power to raise magnitudes to prior to Griffin-Lim
 )
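
These audio settings feed the STFT geometry referenced by the "Adjust fft sampling points and frame length according to audio SR" commit. The audio code itself is not part of this diff, so the derivation below is a sketch with illustrative variable names; it shows how the frame parameters scale with the new 24 kHz sample rate:

# Derive STFT parameters from the hparams above (illustrative, not from the diff).
sample_rate = 24000
frame_length_ms = 50
frame_shift_ms = 12.5
num_freq = 1025

n_fft = (num_freq - 1) * 2                               # 2048 FFT points
hop_length = int(frame_shift_ms / 1000 * sample_rate)    # 300 samples between frames
win_length = int(frame_length_ms / 1000 * sample_rate)   # 1200 samples per window
print(n_fft, hop_length, win_length)                     # -> 2048 300 1200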
201 changes: 201 additions & 0 deletions models/attention.py
@@ -0,0 +1,201 @@
"""Attention file for location based attention (compatible with tensorflow attention wrapper)"""

import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import BahdanauAttention
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops, math_ops, nn_ops, variable_scope


#From https://github.com/tensorflow/tensorflow/blob/r1.7/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
def _compute_attention(attention_mechanism, cell_output, attention_state, attention_layer):
"""Computes the attention and alignments for a given attention_mechanism."""
alignments, next_attention_state = attention_mechanism(
cell_output, state=attention_state)

# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, memory_size]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, memory_size].
# we then squeeze out the singleton dim.
context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1])

if attention_layer is not None:
attention = attention_layer(array_ops.concat([cell_output, context], 1))
else:
attention = context

return attention, alignments, next_attention_state


def _location_sensitive_score(W_query, W_fil, W_keys):
"""Impelements Bahdanau-style (cumulative) scoring function.
This attention is described in:
J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
gio, “Attention-based models for speech recognition,” in Ad-
vances in Neural Information Processing Systems, 2015, pp.
577–585.

#############################################################################
hybrid attention (content-based + location-based)
f = F * α_{i-1}
energy = dot(v_a, tanh(W_keys(h_enc) + W_query(h_dec) + W_fil(f) + b_a))
#############################################################################

Args:
W_query: Tensor, shape '[batch_size, 1, attention_dim]' to compare to location features.
W_location: processed previous alignments into location features, shape '[batch_size, max_time, attention_dim]'
W_keys: Tensor, shape '[batch_size, max_time, attention_dim]', typically the encoder outputs.
Returns:
A '[batch_size, max_time]' attention score (energy)
"""
# Get the number of hidden units from the trailing dimension of keys
dtype = W_query.dtype
num_units = W_keys.shape[-1].value or array_ops.shape(W_keys)[-1]

v_a = tf.get_variable(
'attention_variable', shape=[num_units], dtype=dtype,
initializer=tf.contrib.layers.xavier_initializer())
b_a = tf.get_variable(
'attention_bias', shape=[num_units], dtype=dtype,
initializer=tf.zeros_initializer())

return tf.reduce_sum(v_a * tf.tanh(W_keys + W_query + W_fil + b_a), [2])

def _smoothing_normalization(e):
"""Applies a smoothing normalization function instead of softmax
Introduced in:
J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
gio, “Attention-based models for speech recognition,” in Ad-
vances in Neural Information Processing Systems, 2015, pp.
577–585.

############################################################################
Smoothing normalization function
a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
############################################################################

Args:
e: matrix [batch_size, max_time(memory_time)]: expected to be energy (score)
values of an attention mechanism
Returns:
matrix [batch_size, max_time]: [0, 1] normalized alignments with possible
attendance to multiple memory time steps.
"""
return tf.nn.sigmoid(e) / tf.reduce_sum(tf.nn.sigmoid(e), axis=-1, keepdims=True)


class LocationSensitiveAttention(BahdanauAttention):
"""Impelements Bahdanau-style (cumulative) scoring function.
Usually referred to as "hybrid" attention (content-based + location-based)
Extends the additive attention described in:
"D. Bahdanau, K. Cho, and Y. Bengio, “Neural machine transla-
tion by jointly learning to align and translate,” in Proceedings
of ICLR, 2015."
to use previous alignments as additional location features.

This attention is described in:
J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
gio, “Attention-based models for speech recognition,” in Ad-
vances in Neural Information Processing Systems, 2015, pp.
577–585.
"""

def __init__(self,
num_units,
memory,
smoothing=False,
cumulate_weights=True,
name='LocationSensitiveAttention'):
"""Construct the Attention mechanism.
Args:
num_units: The depth of the query mechanism.
memory: The memory to query; usually the output of an RNN encoder. This
tensor should be shaped `[batch_size, max_time, ...]`.
memory_sequence_length (optional): Sequence lengths for the batch entries
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths. Only relevant if mask_encoder = True.
smoothing (optional): Boolean. Determines which normalization function to use.
Default normalization function (probablity_fn) is softmax. If smoothing is
enabled, we replace softmax with:
a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
Introduced in:
J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
gio, “Attention-based models for speech recognition,” in Ad-
vances in Neural Information Processing Systems, 2015, pp.
577–585.
This is mainly used if the model wants to attend to multiple inputs parts
at the same decoding step. We probably won't be using it since multiple sound
frames may depend from the same character, probably not the way around.
Note:
We still keep it implemented in case we want to test it. They used it in the
paper in the context of speech recognition, where one phoneme may depend on
multiple subsequent sound frames.
name: Name to use when creating ops.
"""
#Create normalization function
#Setting it to None defaults in using softmax
normalization_function = _smoothing_normalization if (smoothing == True) else None
super(LocationSensitiveAttention, self).__init__(
num_units=num_units,
memory=memory,
memory_sequence_length=None,
probability_fn=normalization_function,
name=name)

self.location_convolution = tf.layers.Conv1D(filters=32,
kernel_size=(31, ), padding='same', use_bias=True,
bias_initializer=tf.zeros_initializer(), name='location_features_convolution')
self.location_layer = tf.layers.Dense(units=num_units, use_bias=False,
dtype=tf.float32, name='location_features_layer')
self._cumulate = cumulate_weights

def __call__(self, query, state):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
state (previous alignments): Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
Returns:
alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]` (`alignments_size` is memory's
`max_time`).
"""
previous_alignments = state
with variable_scope.variable_scope(None, "Location_Sensitive_Attention", [query]):

# processed_query shape [batch_size, query_depth] -> [batch_size, attention_dim]
processed_query = self.query_layer(query) if self.query_layer else query
# -> [batch_size, 1, attention_dim]
processed_query = tf.expand_dims(processed_query, 1)

# processed_location_features shape [batch_size, max_time, attention dimension]
# [batch_size, max_time] -> [batch_size, max_time, 1]
expanded_alignments = tf.expand_dims(previous_alignments, axis=2)
# location features [batch_size, max_time, filters]
f = self.location_convolution(expanded_alignments)
# Projected location features [batch_size, max_time, attention_dim]
processed_location_features = self.location_layer(f)

# energy shape [batch_size, max_time]
energy = _location_sensitive_score(processed_query, processed_location_features, self.keys)


# alignments shape = energy shape = [batch_size, max_time]
alignments = self._probability_fn(energy, previous_alignments)

# Cumulate alignments
if self._cumulate:
next_state = alignments + previous_alignments
else:
next_state = alignments

return alignments, next_state
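
For readers wiring this into the rest of the model: below is a minimal sketch (not taken from this PR) of how LocationSensitiveAttention plugs into TF 1.x's seq2seq stack, per the "Replace attention wrapper with location sensitive attention" commits. The depths match the hparams diff above (attention_depth=128, decoder_depth=1024); the encoder shape and cell choice are illustrative assumptions.

import tensorflow as tf
from models.attention import LocationSensitiveAttention

# Illustrative encoder output: [batch, text_time, encoder_depth=256].
encoder_outputs = tf.placeholder(tf.float32, [None, None, 256], 'encoder_outputs')

attention_mechanism = LocationSensitiveAttention(
    num_units=128,                 # attention_depth in hparams.py
    memory=encoder_outputs,
    smoothing=False,               # keep softmax normalization
    cumulate_weights=True)         # cumulative alignments as location features

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.GRUCell(1024),  # decoder_depth in hparams.py
    attention_mechanism,
    alignment_history=True,        # keep per-step alignments for plotting
    output_attention=False)        # concatenate context manually, Tacotron-style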