Skip to content

Commit

Permalink
[TTS] Update docs for period discriminator
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed Jan 11, 2024
1 parent 776258e commit 147197f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.feature_loss_fn = RelativeFeatureMatchingLoss()
elif feature_loss_type == "absolute":
self.feature_loss_fn = FeatureMatchingLoss()
else:
raise ValueError(f'Unknown feature loss type {feature_loss_type}.')

# Codebook loss setup
if self.vector_quantizer:
Expand Down
54 changes: 38 additions & 16 deletions nemo/collections/tts/modules/audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,20 @@ def forward(self, inputs):


class PeriodDiscriminator(NeuralModule):
def __init__(self, period):
"""
Period discriminator introduced in HiFi-GAN (https://arxiv.org/abs/2010.05646), which attempts to
discriminate phase information by looking at equally spaced audio samples.
Args:
period: Spacing between audio sample inputs.
lrelu_slope: Negative slope of the leaky ReLU activation. A slope of 0.1 or 0.2 is recommended for the
stability of the feature matching loss.
"""

def __init__(self, period, lrelu_slope=0.1):
super().__init__()
self.period = period
self.activation = nn.LeakyReLU(0.1)
self.activation = nn.LeakyReLU(lrelu_slope)
self.conv_layers = nn.ModuleList(
[
Conv2dNorm(1, 32, kernel_size=(5, 1), stride=(3, 1)),
Expand All @@ -212,26 +222,30 @@ def input_types(self):
@property
def output_types(self):
return {
"score": NeuralType(('B', 'D', 'T'), VoidType()),
"fmap": [NeuralType(('B', 'C', 'H', 'W'), VoidType())],
"score": NeuralType(('B', 'C', 'T_out'), VoidType()),
"fmap": [NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())],
}

@typecheck()
def forward(self, audio):
# Pad audio

batch_size, time = audio.shape
out = rearrange(audio, 'B T -> B 1 T')
# Pad audio so that it is divisible by the period
if time % self.period != 0:
n_pad = self.period - (time % self.period)
out = F.pad(out, (0, n_pad), "reflect")
time = time + n_pad
# [batch, 1, (time / period), period]
out = out.view(batch_size, 1, time // self.period, self.period)

fmap = []
for conv in self.conv_layers:
# [batch, filters, (time / period / stride), period]
out = conv(inputs=out)
out = self.activation(out)
fmap.append(out)
# [batch, 1, (time / period / strides), period]
score = self.conv_post(inputs=out)
fmap.append(score)
score = rearrange(score, "B 1 T C -> B C T")
Expand All @@ -240,9 +254,17 @@ def forward(self, audio):


class MultiPeriodDiscriminator(NeuralModule):
def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11)):
"""
Wrapper class to aggregate results of multiple period discriminators.
The periods are expected to be increasing prime numbers in order to maximize coverage and minimize overlap.
"""

def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11), lrelu_slope=0.1):
super().__init__()
self.discriminators = nn.ModuleList([PeriodDiscriminator(period) for period in periods])
self.discriminators = nn.ModuleList(
[PeriodDiscriminator(period=period, lrelu_slope=lrelu_slope) for period in periods]
)

@property
def input_types(self):
Expand All @@ -254,10 +276,10 @@ def input_types(self):
@property
def output_types(self):
return {
"scores_real": [NeuralType(('B', 'C', 'T'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
"scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
}

@typecheck()
Expand Down Expand Up @@ -296,10 +318,10 @@ def input_types(self):
@property
def output_types(self):
return {
"scores_real": [NeuralType(('B', 'C', 'T'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'H', 'C'), VoidType())]],
"scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
}

@typecheck()
Expand Down Expand Up @@ -659,4 +681,4 @@ def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor
# concatenate along the feature dimension
dequantized = torch.cat(dequantized, dim=1)

return dequantized
return dequantized

0 comments on commit 147197f

Please sign in to comment.