override compute_gst in tacotron2 model
lexkoro committed Jul 13, 2020
1 parent c4a0f4d commit c5eaf12
Showing 2 changed files with 30 additions and 17 deletions.
models/tacotron2.py (25 changes: 24 additions & 1 deletion)
@@ -47,17 +47,22 @@ def __init__(self,
        decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
        encoder_in_features = 512 if num_speakers > 1 else 512
        proj_speaker_dim = 80 if num_speakers > 1 else 0
        # base layers

        # embedding layer
        self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)

        # speaker embedding layer
        if num_speakers > 1:
            self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)

        self.encoder = Encoder(encoder_in_features)
        self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
                               attn_norm, prenet_type, prenet_dropout,
                               forward_attn, trans_agent, forward_attn_mask,
                               location_attn, attn_K, separate_stopnet, proj_speaker_dim)
        self.postnet = Postnet(self.postnet_output_dim)

        # global style token layers
        if self.gst:
            self.gst_layer = GST(num_mel=80,
@@ -81,6 +86,24 @@ def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
        mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
        return mel_outputs, mel_outputs_postnet, alignments

    def compute_gst(self, inputs, style_input):
        """ Compute global style token """
        device = inputs.device
        if isinstance(style_input, dict):
            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
            for k_token, v_amplifier in style_input.items():
                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
        elif style_input is None:
            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
        else:
            gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
        return inputs, embedded_gst

    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
        # compute mask for padding
        # B x T_in_max (boolean)
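Note: the override lets Tacotron2 condition on style in three ways: a dict mapping style-token indices to amplifier weights, None (a zero style embedding), or a reference mel spectrogram run through the GST layer. Below is a minimal, self-contained sketch of the dict branch; the token count (10), the embedding size (256), and the crude stand-in for the attention read are assumptions for illustration, not values taken from this commit.

import torch

gst_embedding_dim = 256  # assumed; the real value comes from the model config
# Ten random stand-ins for the learned style tokens (count is assumed).
style_tokens = torch.randn(10, 1, gst_embedding_dim // 2)

def combine_style_tokens(style_input, seq_len):
    # Mirrors the dict branch of the overridden compute_gst: each selected
    # token contributes to the style embedding, scaled by its amplifier.
    gst_outputs = torch.zeros(1, 1, gst_embedding_dim)
    tokens = torch.tanh(style_tokens)
    for k_token, v_amplifier in style_input.items():
        token = tokens[int(k_token)].unsqueeze(0)   # 1 x 1 x D/2
        # Crude stand-in for the attention read in the real code, which maps
        # a D/2-dimensional token key to a D-dimensional output.
        token = token.repeat(1, 1, 2)               # 1 x 1 x D
        gst_outputs = gst_outputs + token * v_amplifier
    return gst_outputs.repeat(1, seq_len, 1)        # broadcast over time: 1 x T x D

# Token "0" amplified by 0.3, token "4" damped by -0.1 (keys may be strings,
# since the real code casts them with int()).
embedded_gst = combine_style_tokens({"0": 0.3, "4": -0.1}, seq_len=120)
print(embedded_gst.shape)  # torch.Size([1, 120, 256])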
models/tacotron_abstract.py (22 changes: 6 additions & 16 deletions)
@@ -164,22 +164,12 @@ def compute_speaker_embedding(self, speaker_ids):
        self.speaker_embeddings_projected = self.speaker_project_mel(
            self.speaker_embeddings).squeeze(1)

    def compute_gst(self, inputs, style_input):
        device = inputs.device
        if isinstance(style_input, dict):
            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
            for k_token, v_amplifier in style_input.items():
                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
        elif style_input is None:
            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
        else:
            gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
        return inputs, embedded_gst
    def compute_gst(self, inputs, mel_specs):
        """ Compute global style token """
        # pylint: disable=not-callable
        gst_outputs = self.gst_layer(mel_specs)
        inputs = self._add_speaker_embedding(inputs, gst_outputs)
        return inputs

    @staticmethod
    def _add_speaker_embedding(outputs, speaker_embeddings):
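Note: after this change the base class keeps only the reference-spectrogram path: it encodes mel_specs with the GST layer and folds the resulting style vector into the encoder outputs via _add_speaker_embedding. A toy sketch of that flow follows; the stand-in module, the tensor shapes, and the broadcast-add used in place of _add_speaker_embedding are assumptions, since that method's body is not shown in this diff.

import torch
import torch.nn as nn

class ToyGST(nn.Module):
    # Stand-in for the real GST layer: maps a reference mel spectrogram
    # (B x T_mel x 80, shape assumed) to one style vector per utterance.
    def __init__(self, num_mel=80, dim=256):
        super().__init__()
        self.proj = nn.Linear(num_mel, dim)

    def forward(self, mel):
        return self.proj(mel.mean(dim=1, keepdim=True))  # B x 1 x D

gst_layer = ToyGST()
encoder_outputs = torch.randn(2, 120, 256)  # B x T_in x D (assumed)
mel_specs = torch.randn(2, 400, 80)         # reference style spectrograms

gst_outputs = gst_layer(mel_specs)          # B x 1 x D
# Assumed behavior of _add_speaker_embedding: broadcast the style vector
# over the time axis and add it to the encoder outputs.
conditioned = encoder_outputs + gst_outputs.expand_as(encoder_outputs)
print(conditioned.shape)                    # torch.Size([2, 120, 256])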
