[TTS] Add period discriminator and feature matching loss to codec recipe
Signed-off-by: Ryan <[email protected]>
rlangman committed Nov 14, 2023
1 parent 77d1386 commit 35d0e41
Showing 3 changed files with 193 additions and 3 deletions.
51 changes: 51 additions & 0 deletions nemo/collections/tts/losses/audio_codec_loss.py
@@ -331,7 +331,58 @@ def forward(self, audio_real, audio_gen, audio_len):
return loss


class FeatureMatchingLoss(Loss):
"""
Standard feature matching loss measuring the difference in the internal discriminator layer outputs
(usually leaky relu activations) between real and generated audio, scaled down by the total number of
discriminators and layers.
"""

def __init__(self):
super(FeatureMatchingLoss, self).__init__()

@property
def input_types(self):
return {
"fmaps_real": [[NeuralType(elements_type=VoidType())]],
"fmaps_gen": [[NeuralType(elements_type=VoidType())]],
}

@property
def output_types(self):
return {
"loss": NeuralType(elements_type=LossType()),
}

@typecheck()
def forward(self, fmaps_real, fmaps_gen):
loss = 0.0
for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen):
# [B, ..., time]
for feat_real, feat_gen in zip(fmap_real, fmap_gen):
# [B, ...]
diff = torch.abs(feat_real - feat_gen)
feat_loss = torch.mean(diff) / len(fmap_real)
loss += feat_loss

loss /= len(fmaps_real)

return loss
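
# Illustrative usage (editorial sketch, not part of the commit; shapes are hypothetical).
# The loss expects nested lists: one inner list per discriminator, each holding that
# discriminator's per-layer activations for a batch.
#
#   fm_loss = FeatureMatchingLoss()
#   fmaps_real = [[torch.randn(2, 32, 100), torch.randn(2, 128, 50)]]
#   fmaps_gen = [[torch.randn(2, 32, 100), torch.randn(2, 128, 50)]]
#   loss = fm_loss(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)  # scalar tensor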


class RelativeFeatureMatchingLoss(Loss):
"""
Relative feature matching loss as described in https://arxiv.org/pdf/2210.13438.pdf.
This is similar to standard feature matching loss, but it scales the loss by the absolute value of
each feature averaged across time. This might be slightly different from the paper which says the
"mean is computed over all dimensions", which could imply taking the average across both time and
features.
Args:
div_guard: Value to add when dividing by mean to avoid large/NaN values.
"""

def __init__(self, div_guard=1e-3):
super(RelativeFeatureMatchingLoss, self).__init__()
self.div_guard = div_guard
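
    # The rest of this class is truncated in the diff. A minimal sketch of the relative
    # scaling described in the docstring (an editorial assumption, not the committed code)
    # would normalize each feature difference by the time-averaged magnitude of the real
    # feature before averaging:
    #
    #   diff = torch.mean(torch.abs(feat_real - feat_gen), dim=-1)
    #   feat_scale = torch.mean(torch.abs(feat_real), dim=-1)
    #   feat_loss = torch.mean(diff / (feat_scale + self.div_guard))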
8 changes: 7 additions & 1 deletion nemo/collections/tts/models/audio_codec.py
Expand Up @@ -24,6 +24,7 @@
from pytorch_lightning import Trainer

from nemo.collections.tts.losses.audio_codec_loss import (
FeatureMatchingLoss,
MultiResolutionMelLoss,
MultiResolutionSTFTLoss,
RelativeFeatureMatchingLoss,
@@ -116,7 +117,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0)
        self.gen_loss_fn = instantiate(cfg.generator_loss)
        self.disc_loss_fn = instantiate(cfg.discriminator_loss)
-        self.feature_loss_fn = RelativeFeatureMatchingLoss()

+        feature_loss_type = cfg.get("feature_loss_type", "relative")
+        if feature_loss_type == "relative":
+            self.feature_loss_fn = RelativeFeatureMatchingLoss()
+        elif feature_loss_type == "absolute":
+            self.feature_loss_fn = FeatureMatchingLoss()

        # Codebook loss setup
        if self.vector_quantizer:
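Usage note (editorial): since the loss is read via `cfg.get("feature_loss_type", "relative")`, setting `feature_loss_type: absolute` in the model config selects the new `FeatureMatchingLoss`, while the default `relative` preserves the previous behavior.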
137 changes: 135 additions & 2 deletions nemo/collections/tts/modules/audio_codec_modules.py
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple

import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from nemo.collections.asr.parts.utils.activations import Snake
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
-from nemo.core.neural_types.elements import LengthsType, VoidType
+from nemo.core.neural_types.elements import AudioSignal, LengthsType, VoidType
from nemo.core.neural_types.neural_type import NeuralType


@@ -181,3 +183,134 @@ def remove_weight_norm(self):
@typecheck()
def forward(self, inputs):
return self.conv(inputs)


class PeriodDiscriminator(NeuralModule):
def __init__(self, period):
super().__init__()
self.period = period
self.activation = nn.LeakyReLU(0.1)
self.conv_layers = nn.ModuleList(
[
Conv2dNorm(1, 32, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(32, 128, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(128, 512, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(512, 1024, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(1024, 1024, kernel_size=(5, 1), stride=(1, 1)),
]
)
self.conv_post = Conv2dNorm(1024, 1, kernel_size=(3, 1))

@property
def input_types(self):
return {
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
}

@property
def output_types(self):
return {
"score": NeuralType(('B', 'D', 'T'), VoidType()),
"fmap": [NeuralType(("B", "C", "H", "W"), VoidType())],
}

@typecheck()
def forward(self, audio):
# Pad audio
batch_size, time = audio.shape
out = rearrange(audio, 'B T -> B 1 T')
if time % self.period != 0:
n_pad = self.period - (time % self.period)
out = F.pad(out, (0, n_pad), "reflect")
time = time + n_pad
out = out.view(batch_size, 1, time // self.period, self.period)

fmap = []
for conv in self.conv_layers:
out = conv(inputs=out)
out = self.activation(out)
fmap.append(out)
score = self.conv_post(inputs=out)
fmap.append(score)
score = rearrange(score, "B 1 T C -> B C T")

return score, fmap
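
    # Shape walk-through (editorial, hypothetical sizes): for audio of shape [B, T] with
    # T=16000 and period=3, T % 3 == 1, so the waveform is reflect-padded by 2 samples to
    # T=16002 and viewed as [B, 1, 5334, 3]. The (5, 1) kernels then convolve along the
    # 5334-step axis, so the discriminator compares samples spaced exactly one period apart.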


class MultiPeriodDiscriminator(NeuralModule):
def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11)):
super().__init__()
self.discriminators = nn.ModuleList([PeriodDiscriminator(period) for period in periods])

@property
def input_types(self):
return {
"audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
}

@property
def output_types(self):
return {
"scores_real": [NeuralType(('B', 'C', 'T'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
}

@typecheck()
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
fmaps_real = []
fmaps_gen = []
for discriminator in self.discriminators:
score_real, fmap_real = discriminator(audio=audio_real)
score_gen, fmap_gen = discriminator(audio=audio_gen)
scores_real.append(score_real)
fmaps_real.append(fmap_real)
scores_gen.append(score_gen)
fmaps_gen.append(fmap_gen)

return scores_real, scores_gen, fmaps_real, fmaps_gen


class Discriminator(NeuralModule):
"""
Wrapper class which takes a list of discriminators and aggregates the results across them.
"""

def __init__(self, discriminators: Iterable[NeuralModule]):
super().__init__()
self.discriminators = nn.ModuleList(discriminators)

@property
def input_types(self):
return {
"audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
}

@property
def output_types(self):
return {
"scores_real": [NeuralType(('B', 'C', 'T'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
}

@typecheck()
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
fmaps_real = []
fmaps_gen = []
for discriminator in self.discriminators:
score_real, score_gen, fmap_real, fmap_gen = discriminator(audio_real=audio_real, audio_gen=audio_gen)
scores_real += score_real
fmaps_real += fmap_real
scores_gen += score_gen
fmaps_gen += fmap_gen

return scores_real, scores_gen, fmaps_real, fmaps_gen
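
Putting the pieces together, a minimal end-to-end sketch (editorial, with hypothetical batch and length; a full recipe would typically combine the multi-period discriminator with the existing multi-scale one):

import torch

from nemo.collections.tts.losses.audio_codec_loss import FeatureMatchingLoss
from nemo.collections.tts.modules.audio_codec_modules import Discriminator, MultiPeriodDiscriminator

disc = Discriminator(discriminators=[MultiPeriodDiscriminator()])
audio_real = torch.randn(2, 16000)
audio_gen = torch.randn(2, 16000)

# One score and one feature-map list per period (2, 3, 5, 7, 11).
scores_real, scores_gen, fmaps_real, fmaps_gen = disc(audio_real=audio_real, audio_gen=audio_gen)

fm_loss = FeatureMatchingLoss()
loss = fm_loss(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)  # scalar tensor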
