[TTS] Add cosine distance option to TTS aligner
Signed-off-by: Ryan <[email protected]>
rlangman committed Jun 6, 2023
1 parent 9827c9b commit a7ef1fa
Showing 3 changed files with 66 additions and 30 deletions.
2 changes: 2 additions & 0 deletions examples/tts/conf/fastpitch/fastpitch.yaml
@@ -193,6 +193,8 @@ model:
alignment_module:
_target_: nemo.collections.tts.modules.aligner.AlignmentEncoder
n_text_channels: ${model.symbols_embedding_dim}
dist_type: cosine
temperature: 15.0

duration_predictor:
_target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
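
For context, here is a minimal sketch (not part of the commit) of how the alignment_module block above is consumed, assuming hydra.utils.instantiate as used in fastpitch.py below; literal values stand in for the interpolated ones in the YAML:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical, trimmed-down equivalent of the alignment_module config above.
aligner_cfg = OmegaConf.create(
    {
        "_target_": "nemo.collections.tts.modules.aligner.AlignmentEncoder",
        "n_text_channels": 512,  # stands in for ${model.symbols_embedding_dim}
        "dist_type": "cosine",   # new option in this commit; the constructor default is "l2"
        "temperature": 15.0,     # overrides the constructor default of 0.0005
    }
)
aligner = instantiate(aligner_cfg)
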
16 changes: 5 additions & 11 deletions nemo/collections/tts/models/fastpitch.py
@@ -121,16 +121,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.log_images = cfg.get("log_images", False)
self.log_train_images = False

loss_scale = 0.1 if self.learn_alignment else 1.0
dur_loss_scale = loss_scale
pitch_loss_scale = loss_scale
energy_loss_scale = loss_scale
if "dur_loss_scale" in cfg:
dur_loss_scale = cfg.dur_loss_scale
if "pitch_loss_scale" in cfg:
pitch_loss_scale = cfg.pitch_loss_scale
if "energy_loss_scale" in cfg:
energy_loss_scale = cfg.energy_loss_scale
default_prosody_loss_scale = 0.1 if self.learn_alignment else 1.0
dur_loss_scale = cfg.get("dur_loss_scale", default_prosody_loss_scale)
pitch_loss_scale = cfg.get("pitch_loss_scale", default_prosody_loss_scale)
energy_loss_scale = cfg.get("energy_loss_scale", default_prosody_loss_scale)

self.mel_loss_fn = MelLoss()
self.pitch_loss_fn = PitchLoss(loss_scale=pitch_loss_scale)
@@ -139,7 +133,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self.aligner = None
if self.learn_alignment:
aligner_loss_scale = cfg.aligner_loss_scale if "aligner_loss_scale" in cfg else 1.0
aligner_loss_scale = cfg.get("aligner_loss_scale", 1.0)
self.aligner = instantiate(self._cfg.alignment_module)
self.forward_sum_loss_fn = ForwardSumLoss(loss_scale=aligner_loss_scale)
self.bin_loss_fn = BinLoss(loss_scale=aligner_loss_scale)
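
The constructor change above replaces the per-key "if key in cfg" checks with cfg.get(key, default) and a shared default. A quick sketch of the resulting behavior (hypothetical config values, assuming OmegaConf's DictConfig.get):

from omegaconf import OmegaConf

# Hypothetical model config where only pitch_loss_scale is overridden.
cfg = OmegaConf.create({"pitch_loss_scale": 0.2})

learn_alignment = True
default_prosody_loss_scale = 0.1 if learn_alignment else 1.0

dur_loss_scale = cfg.get("dur_loss_scale", default_prosody_loss_scale)        # 0.1, falls back to the default
pitch_loss_scale = cfg.get("pitch_loss_scale", default_prosody_loss_scale)    # 0.2, taken from the config
energy_loss_scale = cfg.get("energy_loss_scale", default_prosody_loss_scale)  # 0.1, falls back to the default
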
78 changes: 59 additions & 19 deletions nemo/collections/tts/modules/aligner.py
@@ -14,6 +14,7 @@


import torch
from einops import rearrange
from torch import nn

from nemo.collections.tts.modules.submodules import ConditionalInput, ConvNorm
@@ -24,7 +25,13 @@ class AlignmentEncoder(torch.nn.Module):
"""Module for alignment text and mel spectrogram. """

def __init__(
self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=0.0005, condition_types=[]
self,
n_mel_channels=80,
n_text_channels=512,
n_att_channels=80,
temperature=0.0005,
condition_types=[],
dist_type="l2",
):
super().__init__()
self.temperature = temperature
@@ -45,6 +52,20 @@ def __init__(
torch.nn.ReLU(),
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
)
if dist_type == "l2":
self.dist_fn = self.get_euclidean_dist
elif dist_type == "cosine":
self.dist_fn = self.get_cosine_dist
else:
raise ValueError(f"Unknown distance type '{dist_type}'")

@staticmethod
def _apply_mask(inputs, mask, mask_value):
if mask is None:
return

mask = rearrange(mask, "B T2 1 -> B 1 1 T2")
inputs.data.masked_fill_(mask, mask_value)

def get_dist(self, keys, queries, mask=None):
"""Calculation of distance matrix.
@@ -57,15 +78,35 @@ def get_dist(self, keys, queries, mask=None):
Output:
dist (torch.tensor): B x T1 x T2 tensor.
"""
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
queries_enc = self.query_proj(queries) # B x n_attn_dims x T1
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2
dist = attn.sum(1, keepdim=True) # B x 1 x T1 x T2
# B x n_attn_dims x T2
keys_enc = self.key_proj(keys)
# B x n_attn_dims x T1
queries_enc = self.query_proj(queries)

# B x 1 x T1 x T2
dist = self.dist_fn(queries=queries_enc, keys=keys_enc)

self._apply_mask(dist, mask, float("inf"))

if mask is not None:
dist.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), float("inf"))
return dist

return dist.squeeze(1)
@staticmethod
def get_euclidean_dist(queries, keys):
queries = rearrange(queries, "B C T1 -> B C T1 1")
keys = rearrange(keys, "B C T2 -> B C 1 T2")
# B x C x T1 x T2
distance = (queries - keys) ** 2
# B x 1 x T1 x T2
l2_dist = distance.sum(axis=1, keepdim=True)
return l2_dist

@staticmethod
def get_cosine_dist(queries, keys):
queries = rearrange(queries, "B C T1 -> B C T1 1")
keys = rearrange(keys, "B C T2 -> B C 1 T2")
cosine_dist = -torch.nn.functional.cosine_similarity(queries, keys, dim=1)
cosine_dist = rearrange(cosine_dist, "B T1 T2 -> B 1 T1 T2")
return cosine_dist

@staticmethod
def get_durations(attn_soft, text_len, spect_len):
@@ -96,8 +137,7 @@ def get_mean_dist_by_durations(dist, durations, mask=None):
batch_size, t1_size, t2_size = dist.size()
assert torch.all(torch.eq(durations.sum(dim=1), t1_size))

if mask is not None:
dist = dist.masked_fill(mask.permute(0, 2, 1).unsqueeze(2), 0)
AlignmentEncoder._apply_mask(dist, mask, 0)

# TODO(oktai15): make it more efficient
mean_dist_by_durations = []
@@ -149,7 +189,7 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
"""Forward pass of the aligner encoder.
Args:
queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
queries (torch.tensor): B x C1 x T1 tensor (probably going to be mel data).
keys (torch.tensor): B x C2 x T2 tensor (text data).
mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries (True = mask element, False = leave unchanged).
attn_prior (torch.tensor): prior for attention matrix.
@@ -159,20 +199,20 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
"""
keys = self.cond_input(keys.transpose(1, 2), conditioning).transpose(1, 2)
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
queries_enc = self.query_proj(queries) # B x n_attn_dims x T1

# Simplistic Gaussian Isotopic Attention
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2
attn = -self.temperature * attn.sum(1, keepdim=True)
# B x C x T2
keys_enc = self.key_proj(keys)
# B x C x T1
queries_enc = self.query_proj(queries)
# B x 1 x T1 x T2
distance = self.dist_fn(queries=queries_enc, keys=keys_enc)
attn = -self.temperature * distance

if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

attn_logprob = attn.clone()

if mask is not None:
attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))
self._apply_mask(attn, mask, -float("inf"))

attn = self.softmax(attn) # softmax along T2
return attn, attn_logprob
