Rename g as spk_emb
erogol committed May 17, 2022
1 parent 2d29e82 commit 8adcd1d
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions TTS/tts/models/forward_tts.py
@@ -170,6 +170,11 @@ class ForwardTTS(BaseTTS):
If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each
input character as in the FastPitch model.
+ ::
+     |-----> (optional) PitchPredictor(o_en, spk_emb) --> pitch_emb --> o_en = o_en + pitch_emb -----|  -> CondConv(spk_emb) -> spk_proj
+     spk, text -> Encoder(text, spk) --> o_en, spk_emb -----> DurationPredictor(o_en, spk_emb) --> dur --> Expand(o_en, dur) -> PositionEncoding(o_en_expand) -> Decoder(o_en_expand_pos, spk_proj) -> mel_out
`ForwardTTS` can be configured to one of these architectures,
- FastPitch
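Read left to right, the diagram above sketches the stages of the forward pass. The snippet below restates that flow as plain Python purely for readability: the callable names (`encoder`, `duration_predictor`, `cond_conv`, and so on) are hypothetical stand-ins taken from the diagram, not the model's actual attribute names, and the real implementation is the `forward()` hunk further down.

```python
# Rough reading of the docstring diagram, NOT the actual ForwardTTS code.
# All callables are hypothetical stand-ins passed in by the caller.
def forward_tts_flow(text, spk, *, encoder, duration_predictor, pitch_predictor,
                     cond_conv, expand, pos_encoding, decoder, use_pitch=True):
    """Compose the stages in the order sketched in the docstring diagram."""
    o_en, spk_emb = encoder(text, spk)            # text + speaker -> encoder outputs, speaker embedding
    dur = duration_predictor(o_en, spk_emb)       # per-character durations
    if use_pitch:
        pitch_emb = pitch_predictor(o_en, spk_emb)
        o_en = o_en + pitch_emb                   # add pitch embedding to the encoder outputs
    spk_proj = cond_conv(spk_emb)                 # project the speaker embedding for the decoder
    o_en_expand = expand(o_en, dur)               # repeat encoder frames by predicted durations
    o_en_expand_pos = pos_encoding(o_en_expand)   # add positional information
    mel_out = decoder(o_en_expand_pos, spk_proj)  # decode to a mel-spectrogram
    return mel_out
```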
@@ -610,19 +615,19 @@ def forward(
- g: :math:`[B, C]`
- pitch: :math:`[B, 1, T]`
"""
- g = self._set_speaker_input(aux_input)
+ spk = self._set_speaker_input(aux_input)
# compute sequence masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() # [B, 1, T_max2]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() # [B, 1, T_max]
# encoder pass
- x_emb, x_mask, g, o_en = self._forward_encoder(
-     x, x_mask, g
+ x_emb, x_mask, spk_emb, o_en = self._forward_encoder(
+     x, x_mask, spk
) # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
# duration predictor pass
if self.args.detach_duration_predictor:
- o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g) # [B, 1, T_max]
+ o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=spk_emb) # [B, 1, T_max]
else:
- o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g) # [B, 1, T_max]
+ o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=spk_emb) # [B, 1, T_max]
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
# generate attn mask from predicted durations
dur_predictor_attn = self.generate_attn(o_dr.squeeze(1), x_mask) # [B, T_max, T_max2']
avg_pitch = None
if self.args.use_pitch:
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(
- o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
+ o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=spk_emb
)
o_en = o_en + o_pitch_emb
# expand encoder outputs
y_mask, o_en_ex, attn = self._expand_encoder(
o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask
) # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max]
# decoder pass
- o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g) # [B, T_max2, C_de]
+ o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb) # [B, T_max2, C_de]
outputs = {
"model_outputs": o_de, # [B, T, C]
"g": g, # [B, C]
"spk_emb": spk_emb, # [B, C]
"durations_log": o_dr_log.squeeze(1), # [B, T]
"durations": o_dr.squeeze(1), # [B, T]
"attn_durations": dur_predictor_attn, # for visualization [B, T_en, T_de']
@@ -688,33 +693,33 @@ def inference(
- x_lengths: [B]
- g: [B, C]
"""
- g = self._set_speaker_input(aux_input)
+ spk = self._set_speaker_input(aux_input)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
# encoder pass
- _, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)
+ _, x_mask, spk_emb, o_en = self._forward_encoder(x, x_mask, spk)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en, x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
# pitch predictor pass
o_pitch = None
if self.args.use_pitch:
- o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
+ o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en=o_en, x_mask=x_mask, g=spk_emb)
o_en = o_en + o_pitch_emb
# expand encoder outputs
y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask)
outputs = {
"alignments": attn,
"pitch": o_pitch,
"durations": o_dr,
"g": g,
"spk_emb": spk_emb,
}
if skip_decoder:
outputs["o_en_ex"] = o_en_ex
else:
# decoder pass
outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)
outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb)
return outputs

def train_step(self, batch: dict, criterion: nn.Module):
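The rename shows up on the inference path too: the returned dict now carries `"spk_emb"`, and skipping the decoder returns the expanded encoder states instead of a spectrogram. A minimal sketch under the same assumptions as above, plus the assumption (suggested by the hunk) that `skip_decoder` is an argument of `inference()`.

```python
import torch

# Illustrative inference call; assumes `model` is the same configured ForwardTTS
# instance and d-vector setup as in the forward() sketch above.
x = torch.randint(0, 100, (1, 13))   # [B, T_max] token IDs
d_vector = torch.randn(1, 512)       # [B, C] speaker embedding; C must match the config

with torch.no_grad():
    out = model.inference(x, aux_input={"d_vectors": d_vector, "speaker_ids": None})
    mel = out["model_outputs"]       # decoded spectrogram
    spk = out["spk_emb"]             # previously out["g"]

    # With skip_decoder=True the decoder pass is skipped and the expanded
    # encoder states are returned under "o_en_ex" instead of "model_outputs".
    enc_only = model.inference(
        x, aux_input={"d_vectors": d_vector, "speaker_ids": None}, skip_decoder=True
    )
    o_en_ex = enc_only["o_en_ex"]    # [B, C_en, T_max2]
```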
