Skip to content

Commit

Permalink
[TTS] Add mel band validation
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed Feb 12, 2024
1 parent 4b6843e commit b72fd08
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# This config contains the default values for training 24khz audio codec model
# This config contains the default values for training 44.1kHz audio codec model which encodes mel spectrogram
# instead of raw audio.
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: SpeechCodec
name: MelCodec

max_epochs: ???
# Adjust batch size based on GPU memory
Expand Down Expand Up @@ -44,7 +45,7 @@ model:
commit_loss_scale: 0.0

# Probability of updating the discriminator during each training step
# For example, update the discriminator 1/2 times (1 updates for every 2 batches)
# For example, update the discriminator 1/2 times (1 update for every 2 batches)
disc_updates_per_period: 1
disc_update_period: 2

Expand Down Expand Up @@ -128,13 +129,13 @@ model:
audio_decoder:
_target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder
up_sample_rates: ${up_sample_rates}
input_dim: 32
input_dim: 32 # Should be equal to len(audio_encoder.mel_bands) * audio_encoder.out_channels
base_channels: 1024 # This is double the base channels of HiFi-GAN V1, making it approximately 4x larger.

vector_quantizer:
_target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer
num_groups: 8
num_levels_per_group: [8, 5, 5, 5] # 8 x 5 x 5 x 5 = 1000 entries per codebook
num_groups: 8 # Should equal len(audio_encoder.mel_bands)
num_levels_per_group: [8, 5, 5, 5] # 8 * 5 * 5 * 5 = 1000 entries per codebook

discriminator:
_target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator
Expand Down
15 changes: 15 additions & 0 deletions nemo/collections/tts/modules/audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -962,6 +963,7 @@ class MelSpectrogramProcessor(NeuralModule):

def __init__(self, sample_rate: int, win_length: int, hop_length: int, mel_dim: int = 80, log_guard: float = 1.0):
super(MelSpectrogramProcessor, self).__init__()
self.mel_dim = mel_dim
self.hop_length = hop_length
self.preprocessor = AudioToMelSpectrogramPreprocessor(
sample_rate=sample_rate,
Expand Down Expand Up @@ -1090,13 +1092,26 @@ class MultiBandMelEncoder(NeuralModule):

def __init__(self, mel_bands: Iterable[Tuple[int, int]], mel_processor: NeuralModule, **encoder_kwargs):
super(MultiBandMelEncoder, self).__init__()
self.validate_mel_bands(mel_dim=mel_processor.mel_dim, mel_bands=mel_bands)
self.mel_bands = mel_bands
self.mel_processor = mel_processor
band_dims = [band[1] - band[0] for band in self.mel_bands]
self.encoders = nn.ModuleList(
[ResNetEncoder(in_channels=band_dim, **encoder_kwargs) for band_dim in band_dims]
)

@staticmethod
def validate_mel_bands(mel_dim: int, mel_bands: Iterable[Tuple[int, int]]):
mel_dims_used = np.zeros([mel_dim], dtype=bool)
for band in mel_bands:
mel_dims_used[band[0] : band[1]] = True

if not all(mel_dims_used):
missing_dims = np.where(~mel_dims_used)
raise ValueError(f"Mel bands must cover all {mel_dim} dimensions. Missing {missing_dims}.")

return

def remove_weight_norm(self):
for encoder in self.encoders:
encoder.remove_weight_norm()
Expand Down

0 comments on commit b72fd08

Please sign in to comment.