diff --git a/examples/tts/conf/audio_codec/mel_codec_44100.yaml b/examples/tts/conf/audio_codec/mel_codec_44100.yaml
new file mode 100644
index 000000000000..15d12f009ae0
--- /dev/null
+++ b/examples/tts/conf/audio_codec/mel_codec_44100.yaml
@@ -0,0 +1,200 @@
+# This config contains the default values for training a 44.1 kHz audio codec model which encodes the mel spectrogram
+# instead of raw audio.
+# To train the model on a different dataset, adjust the config values according to your dataset.
+# Most dataset-specific arguments are at the top of the config file; see below.
+
+name: MelCodec
+
+max_epochs: ???
+# Adjust batch size based on GPU memory
+batch_size: 16
+# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
+# If null, then weighted sampling is disabled.
+weighted_sampling_steps_per_epoch: null
+
+# Dataset metadata for each manifest
+# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
+train_ds_meta: ???
+val_ds_meta: ???
+
+log_ds_meta: ???
+log_dir: ???
+
+# Modify these values based on your sample rate
+sample_rate: 44100
+win_length: 2048
+hop_length: 512
+train_n_samples: 16384 # ~0.37 seconds
+# The product of the up_sample_rates should match the hop_length.
+# For example 8 * 8 * 4 * 2 = 512.
+up_sample_rates: [8, 8, 4, 2]
+
+
+model:
+
+  max_epochs: ${max_epochs}
+  steps_per_epoch: ${weighted_sampling_steps_per_epoch}
+
+  sample_rate: ${sample_rate}
+  samples_per_frame: ${hop_length}
+
+  mel_loss_l1_scale: 1.0
+  mel_loss_l2_scale: 0.0
+  stft_loss_scale: 20.0
+  time_domain_loss_scale: 0.0
+  commit_loss_scale: 0.0
+
+  # Probability of updating the discriminator during each training step.
+  # For example, update the discriminator half of the time (1 update for every 2 batches).
+  disc_updates_per_period: 1
+  disc_update_period: 2
+
+  # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
+  loss_resolutions: [
+    [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]
+  ]
+  mel_loss_dims: [5, 10, 20, 40, 80, 160, 320]
+  mel_loss_log_guard: 1.0
+  stft_loss_log_guard: 1.0
+  feature_loss_type: absolute
+
+  train_ds:
+    dataset:
+      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
+      dataset_meta: ${train_ds_meta}
+      weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
+      sample_rate: ${sample_rate}
+      n_samples: ${train_n_samples}
+      min_duration: 0.4
+      max_duration: null
+
+    dataloader_params:
+      batch_size: ${batch_size}
+      drop_last: true
+      num_workers: 4
+
+  validation_ds:
+    dataset:
+      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
+      sample_rate: ${sample_rate}
+      n_samples: null
+      min_duration: null
+      max_duration: null
+      trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss
+      dataset_meta: ${val_ds_meta}
+
+    dataloader_params:
+      batch_size: 4
+      num_workers: 2
+
+  # Configures how audio samples are generated and saved during training.
+  # Remove this section to disable logging.
+  log_config:
+    log_dir: ${log_dir}
+    log_epochs: [10, 50, 100, 150, 200]
+    epoch_frequency: 100
+    log_tensorboard: false
+    log_wandb: true
+
+    generators:
+      - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
+        log_audio: true
+        log_encoding: true
+        log_dequantized: true
+
+    dataset:
+      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
+      sample_rate: ${sample_rate}
+      n_samples: null
+      min_duration: null
+      max_duration: null
+      trunc_duration: 10.0 # Only log the first 10 seconds of generated audio.
+      dataset_meta: ${log_ds_meta}
+
+    dataloader_params:
+      batch_size: 4
+      num_workers: 2
+
+  audio_encoder:
+    _target_: nemo.collections.tts.modules.audio_codec_modules.MultiBandMelEncoder
+    mel_bands: [[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80]]
+    out_channels: 4 # The dimension of each codebook
+    hidden_channels: 128
+    filters: 256
+    mel_processor:
+      _target_: nemo.collections.tts.modules.audio_codec_modules.MelSpectrogramProcessor
+      mel_dim: 80
+      sample_rate: ${sample_rate}
+      win_length: ${win_length}
+      hop_length: ${hop_length}
+
+  audio_decoder:
+    _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder
+    up_sample_rates: ${up_sample_rates}
+    input_dim: 32 # Should be equal to len(audio_encoder.mel_bands) * audio_encoder.out_channels
+    base_channels: 1024 # This is double the base channels of HiFi-GAN V1, making it approximately 4x larger.
+
+  vector_quantizer:
+    _target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer
+    num_groups: 8 # Should equal len(audio_encoder.mel_bands)
+    num_levels_per_group: [8, 5, 5, 5] # 8 * 5 * 5 * 5 = 1000 entries per codebook
+
+  discriminator:
+    _target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator
+    discriminators:
+      - _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT
+        resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]
+      - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator
+
+  # The original EnCodec uses a hinge loss, but the squared-GAN loss is more stable
+  # and reduces the need to tune the loss weights or use a gradient balancer.
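+  # For reference, the squared (least-squares) GAN objectives used below are, schematically:
+  #   L_D = E[(D(x) - 1)^2] + E[D(x_hat)^2]
+  #   L_G = E[(D(x_hat) - 1)^2]
+  # where x is real audio and x_hat is the reconstructed audio.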
+  generator_loss:
+    _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss
+
+  discriminator_loss:
+    _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss
+
+  optim:
+    _target_: torch.optim.Adam
+    lr: 2e-4
+    betas: [0.8, 0.99]
+
+  sched:
+    name: ExponentialLR
+    gamma: 0.998
+
+trainer:
+  num_nodes: 1
+  devices: 1
+  accelerator: gpu
+  strategy: ddp_find_unused_parameters_true
+  precision: 16
+  max_epochs: ${max_epochs}
+  accumulate_grad_batches: 1
+  enable_checkpointing: False # Provided by exp_manager
+  logger: false # Provided by exp_manager
+  log_every_n_steps: 100
+  check_val_every_n_epoch: 5
+  benchmark: false
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+  create_tensorboard_logger: false
+  create_wandb_logger: true
+  wandb_logger_kwargs:
+    name: null
+    project: null
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    monitor: val_loss
+    mode: min
+    save_top_k: 5
+    save_best_model: true
+    always_save_nemo: true
+  resume_if_exists: false
+  resume_ignore_no_checkpoint: false
diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py
index 933d09a67fa1..96029d9bd105 100644
--- a/nemo/collections/tts/modules/audio_codec_modules.py
+++ b/nemo/collections/tts/modules/audio_codec_modules.py
@@ -15,16 +15,25 @@
 from abc import ABC, abstractmethod
 from typing import Iterable, List, Optional, Tuple
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
+from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
 from nemo.collections.asr.parts.utils.activations import Snake
 from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
 from nemo.core.classes.common import typecheck
 from nemo.core.classes.module import NeuralModule
-from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, VoidType
+from nemo.core.neural_types.elements import (
+    AudioSignal,
+    EncodedRepresentation,
+    Index,
+    LengthsType,
+    MelSpectrogramType,
+    VoidType,
+)
 from nemo.core.neural_types.neural_type import NeuralType
 from nemo.utils import logging
 
@@ -50,16 +59,26 @@ def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]:
 
 class CodecActivation(nn.Module):
     """
-    Choose between snake or Elu activation based on the input parameter.
+    Choose an activation based on the input parameter.
+
+    Args:
+        activation: Name of activation to use. Valid options are "elu" (default), "lrelu", and "snake".
+        channels: Input dimension.
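+
+    Example (illustrative usage):
+        >>> act = CodecActivation(activation="snake", channels=128)
+        >>> out = act(torch.rand(2, 128, 100))  # output keeps the input shape [2, 128, 100]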
""" def __init__(self, activation: str = "elu", channels: int = 1): super().__init__() activation = activation.lower() - if activation == "snake": - self.activation = Snake(channels) - elif activation == "elu": + if activation == "elu": self.activation = nn.ELU() + elif activation == "lrelu": + self.activation = torch.nn.LeakyReLU() + elif activation == "snake": + self.activation = Snake(channels) else: raise ValueError(f"Unknown activation {activation}") @@ -69,17 +84,24 @@ def forward(self, x): class Conv1dNorm(NeuralModule): def __init__( - self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Optional[int] = None, + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + padding: Optional[int] = None, ): super().__init__() if not padding: - padding = get_padding(kernel_size) + padding = get_padding(kernel_size=kernel_size, dilation=dilation) conv = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, padding_mode="reflect", ) self.conv = nn.utils.weight_norm(conv) @@ -698,3 +720,490 @@ def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor dequantized = torch.cat(dequantized, dim=1) return dequantized + + +class ResidualBlock(NeuralModule): + """ + The residual block structure defined by the HiFi-GAN V1 and V2 configurations. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + dilation: Dilation of the residual convolutions. + dropout_rate: Dropout to apply to residuals. + activation: Activation to apply in between residual convolutions. + """ + + def __init__( + self, + channels: int, + filters: int, + kernel_size: int = 3, + dilation: int = 1, + dropout_rate: float = 0.0, + activation: str = "lrelu", + ): + super(ResidualBlock, self).__init__() + + self.input_activation = CodecActivation(activation=activation, channels=channels) + self.skip_activation = CodecActivation(activation=activation, channels=filters) + self.dropout = torch.nn.Dropout(dropout_rate) + self.input_conv = Conv1dNorm( + in_channels=channels, out_channels=filters, kernel_size=kernel_size, dilation=dilation + ) + self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} + + @property + def output_types(self): + return {"out": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} + + @typecheck() + def forward(self, inputs, input_len): + conv_input = self.input_activation(inputs) + skip_input = self.input_conv(inputs=conv_input, input_len=input_len) + skip_input = self.skip_activation(skip_input) + res = self.skip_conv(inputs=skip_input, input_len=input_len) + res = self.dropout(res) + out = inputs + res + return out + + +class HiFiGANResBlock(NeuralModule): + """ + Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. + + Args: + channels: Input dimension. + kernel_size: Kernel size of the residual blocks. + dilations: List of dilations. One residual block will be created for each dilation in the list. + activation: Activation for the residual blocks. 
+ """ + + def __init__(self, channels: int, kernel_size: int, dilations: Iterable[int], activation: str): + super().__init__() + + self.res_blocks = nn.ModuleList( + [ + ResidualBlock( + channels=channels, + filters=channels, + kernel_size=kernel_size, + dilation=dilation, + activation=activation, + ) + for dilation in dilations + ] + ) + + def remove_weight_norm(self): + for res_block in self.res_blocks: + res_block.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return {"out": NeuralType(('B', 'C', 'T'), VoidType())} + + @typecheck() + def forward(self, inputs, input_len): + out = inputs + for res_block in self.res_blocks: + out = res_block(inputs=out, input_len=input_len) + return out + + +class HiFiGANResLayer(NeuralModule): + """ + Residual block wrapper for HiFi-GAN which creates a block for multiple kernel sizes and dilations. + One residual block is created for each combination of kernel size and dilation. + + Args: + channels: Input dimension. + kernel_sizes: List of kernel sizes. + dilations: List of dilations. + activation: Activation for the residual layers. + + """ + + def __init__(self, channels: int, kernel_sizes: Iterable[int], dilations: Iterable[int], activation: str): + super().__init__() + + self.res_blocks = nn.ModuleList( + [ + HiFiGANResBlock(channels=channels, kernel_size=kernel_size, dilations=dilations, activation=activation) + for kernel_size in kernel_sizes + ] + ) + + def remove_weight_norm(self): + for res_block in self.res_blocks: + res_block.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return {"out": NeuralType(('B', 'D', 'T'), VoidType())} + + @typecheck() + def forward(self, inputs, input_len): + residuals = [res_block(inputs=inputs, input_len=input_len) for res_block in self.res_blocks] + out = sum(residuals) / len(residuals) + return out + + +class HiFiGANDecoder(NeuralModule): + """ + Codec decoder using the HiFi-GAN generator architecture. + + Default parameters match the HiFi-GAN V1 configuration for 22.05khz. + + Args: + input_dim: Input dimension. + up_sample_rates: Rate to upsample for each decoder block. The product of the upsample rates will + determine the output frame rate. For example 8 * 8 * 2 * 2 = 256 samples per token. + base_channels: Number of filters in the first convolution. The number of channels will be cut in + half after each upsample layer. + in_kernel_size: Kernel size of the input convolution. + out_kernel_size: Kernel size of the output convolution. + resblock_kernel_sizes: List of kernel sizes to use in each residual block. + resblock_dilation_sizes: List of dilations to use in each residual block. + activation: Activation to use in residual and upsample layers, defaults to leaky relu. 
+ """ + + def __init__( + self, + input_dim: int, + up_sample_rates: Iterable[int] = (8, 8, 2, 2), + base_channels: int = 512, + in_kernel_size: int = 7, + out_kernel_size: int = 3, + resblock_kernel_sizes: Iterable[int] = (3, 7, 11), + resblock_dilation_sizes: Iterable[int] = (1, 3, 5), + activation: str = "lrelu", + ): + assert in_kernel_size > 0 + assert out_kernel_size > 0 + + super().__init__() + + self.up_sample_rates = up_sample_rates + self.pre_conv = Conv1dNorm(in_channels=input_dim, out_channels=base_channels, kernel_size=in_kernel_size) + + in_channels = base_channels + self.activations = nn.ModuleList([]) + self.up_sample_conv_layers = nn.ModuleList([]) + self.res_layers = nn.ModuleList([]) + for i, up_sample_rate in enumerate(self.up_sample_rates): + out_channels = in_channels // 2 + kernel_size = 2 * up_sample_rate + + act = CodecActivation(activation, channels=in_channels) + self.activations.append(act) + + up_sample_conv = ConvTranspose1dNorm( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate + ) + in_channels = out_channels + self.up_sample_conv_layers.append(up_sample_conv) + + res_layer = HiFiGANResLayer( + channels=in_channels, + kernel_sizes=resblock_kernel_sizes, + dilations=resblock_dilation_sizes, + activation=activation, + ) + self.res_layers.append(res_layer) + + self.post_activation = CodecActivation(activation, channels=in_channels) + self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size) + self.out_activation = nn.Tanh() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T_encoded'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + for up_sample_conv in self.up_sample_conv_layers: + up_sample_conv.remove_weight_norm() + for res_layer in self.res_layers: + res_layer.remove_weight_norm() + + @typecheck() + def forward(self, inputs, input_len): + audio_len = input_len + # [B, C, T_encoded] + out = self.pre_conv(inputs=inputs, input_len=audio_len) + for act, res_layer, up_sample_conv, up_sample_rate in zip( + self.activations, self.res_layers, self.up_sample_conv_layers, self.up_sample_rates + ): + audio_len = audio_len * up_sample_rate + out = act(out) + # [B, C / 2, T * up_sample_rate] + out = up_sample_conv(inputs=out, input_len=audio_len) + out = res_layer(inputs=out, input_len=audio_len) + + out = self.post_activation(out) + # [B, 1, T_audio] + out = self.post_conv(inputs=out, input_len=audio_len) + audio = self.out_activation(out) + audio = rearrange(audio, "B 1 T -> B T") + return audio, audio_len + + +class MelSpectrogramProcessor(NeuralModule): + """ + Wrapper interface for computing mel spectrogram for codec training. 
+ """ + + def __init__(self, sample_rate: int, win_length: int, hop_length: int, mel_dim: int = 80, log_guard: float = 1.0): + super(MelSpectrogramProcessor, self).__init__() + self.mel_dim = mel_dim + self.hop_length = hop_length + self.preprocessor = AudioToMelSpectrogramPreprocessor( + sample_rate=sample_rate, + highfreq=None, + features=mel_dim, + pad_to=1, + exact_pad=True, + n_window_size=win_length, + n_window_stride=hop_length, + window_size=False, + window_stride=False, + n_fft=win_length, + mag_power=1.0, + log=True, + log_zero_guard_type="add", + log_zero_guard_value=log_guard, + mel_norm=None, + normalize=None, + preemph=None, + dither=0.0, + ) + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), + "spec_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + spec, spec_len = self.preprocessor(input_signal=audio, length=audio_len) + return spec, spec_len + + +class ResNetEncoder(NeuralModule): + """ + Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing + the time dimension. + + Args: + in_channels: input dimension + out_channels: output dimension + num_layers: number of residual blocks to use + hidden_channels: encoder hidden dimension + filters: number of filters in residual block layers + kernel_size: kernel size in residual block convolutions + dropout_rate: Optional dropout rate to apply to residuals. + activation: Activation to use, defaults to leaky relu. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 6, + hidden_channels: int = 256, + filters: int = 768, + kernel_size: int = 3, + dropout_rate: float = 0.1, + activation: str = "lrelu", + ): + super(ResNetEncoder, self).__init__() + + self.pre_conv = Conv1dNorm(in_channels=in_channels, out_channels=hidden_channels, kernel_size=kernel_size) + self.res_layers = nn.ModuleList( + [ + ResidualBlock( + channels=hidden_channels, + filters=filters, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + activation=activation, + ) + for _ in range(num_layers) + ] + ) + self.post_activation = CodecActivation(activation, channels=hidden_channels) + self.post_conv = Conv1dNorm(in_channels=hidden_channels, out_channels=out_channels, kernel_size=kernel_size) + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + self.post_conv.remove_weight_norm() + for res_layer in self.res_layers: + res_layer.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return {"encoded": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} + + @typecheck() + def forward(self, inputs, input_len): + encoded = self.pre_conv(inputs=inputs, input_len=input_len) + for res_layer in self.res_layers: + encoded = res_layer(inputs=encoded, input_len=input_len) + encoded = self.post_activation(encoded) + encoded = self.post_conv(inputs=encoded, input_len=input_len) + return encoded + + +class FullBandMelEncoder(NeuralModule): + """ + Encoder which encodes the entire mel spectrogram with a single encoder network. 
+
+    Args:
+        mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from
+            input audio.
+        encoder: ResNetEncoder or equivalent class for encoding the mel spectrogram.
+    """
+
+    def __init__(self, mel_processor: NeuralModule, encoder: NeuralModule):
+        super(FullBandMelEncoder, self).__init__()
+        self.mel_processor = mel_processor
+        self.encoder = encoder
+
+    def remove_weight_norm(self):
+        self.encoder.remove_weight_norm()
+
+    @property
+    def input_types(self):
+        return {
+            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
+            "audio_len": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        return {
+            "encoded": NeuralType(('B', 'C', 'T_encoded'), EncodedRepresentation()),
+            "encoded_len": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @typecheck()
+    def forward(self, audio, audio_len):
+        out, spec_len = self.mel_processor(audio=audio, audio_len=audio_len)
+        encoded = self.encoder(inputs=out, input_len=spec_len)
+        return encoded, spec_len
+
+
+class MultiBandMelEncoder(NeuralModule):
+    """
+    Encoder which splits the mel spectrogram into bands and encodes each band using a separate residual network.
+
+    Args:
+        mel_bands: List of mel spectrogram bands to encode.
+            Each list element is a tuple of 2 elements with the start and end index of the mel features to use.
+        mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from
+            input audio.
+        encoder_kwargs: Arguments for constructing the encoder for each mel band.
+    """
+
+    def __init__(self, mel_bands: Iterable[Tuple[int, int]], mel_processor: NeuralModule, **encoder_kwargs):
+        super(MultiBandMelEncoder, self).__init__()
+        self.validate_mel_bands(mel_dim=mel_processor.mel_dim, mel_bands=mel_bands)
+        self.mel_bands = mel_bands
+        self.mel_processor = mel_processor
+        band_dims = [band[1] - band[0] for band in self.mel_bands]
+        self.encoders = nn.ModuleList(
+            [ResNetEncoder(in_channels=band_dim, **encoder_kwargs) for band_dim in band_dims]
+        )
+
+    @staticmethod
+    def validate_mel_bands(mel_dim: int, mel_bands: Iterable[Tuple[int, int]]):
+        mel_dims_used = np.zeros([mel_dim], dtype=bool)
+        for band in mel_bands:
+            mel_dims_used[band[0] : band[1]] = True
+
+        if not all(mel_dims_used):
+            missing_dims = np.where(~mel_dims_used)[0]
+            raise ValueError(f"Mel bands must cover all {mel_dim} dimensions. Missing {missing_dims}.")
+
+        return
+
+    def remove_weight_norm(self):
+        for encoder in self.encoders:
+            encoder.remove_weight_norm()
+
+    @property
+    def input_types(self):
+        return {
+            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
+            "audio_len": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        return {
+            "encoded": NeuralType(('B', 'C', 'T_encoded'), EncodedRepresentation()),
+            "encoded_len": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @typecheck()
+    def forward(self, audio, audio_len):
+        spec, spec_len = self.mel_processor(audio=audio, audio_len=audio_len)
+        outputs = []
+        for (band_start, band_end), encoder in zip(self.mel_bands, self.encoders):
+            # [B, D_band, T]
+            spec_band = spec[:, band_start:band_end, :]
+            band_out = encoder(inputs=spec_band, input_len=spec_len)
+            outputs.append(band_out)
+        # [B, C, T]
+        encoded = torch.cat(outputs, dim=1)
+        return encoded, spec_len
diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py
index 6d4f88c0f417..28de02b6afb4 100644
--- a/tests/collections/tts/modules/test_audio_codec_modules.py
+++ b/tests/collections/tts/modules/test_audio_codec_modules.py
@@ -21,15 +21,22 @@
     ConvTranspose1dNorm,
     FiniteScalarQuantizer,
     GroupFiniteScalarQuantizer,
+    HiFiGANDecoder,
+    MelSpectrogramProcessor,
+    MultiBandMelEncoder,
+    ResidualBlock,
+    ResNetEncoder,
     get_down_sample_padding,
 )
 from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer
+from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
 
 
 class TestAudioCodecModules:
     def setup_class(self):
         self.in_channels = 8
         self.out_channels = 16
+        self.filters = 32
         self.batch_size = 2
         self.len1 = 4
         self.len2 = 8
@@ -98,6 +105,103 @@ def test_conv1d_transpose_upsample(self):
         assert torch.all(out[1, :, :out_len_2] != 0.0)
         assert torch.all(out[1, :, out_len_2:] == 0.0)
 
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_residual_block(self):
+        lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)
+        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
+        inputs = mask_sequence_tensor(tensor=inputs, lengths=lengths)
+
+        res_block = ResidualBlock(channels=self.in_channels, filters=self.filters)
+        out = res_block(inputs=inputs, input_len=lengths)
+
+        assert out.shape == (self.batch_size, self.in_channels, self.max_len)
+        assert torch.all(out[0, :, : self.len1] != 0.0)
+        assert torch.all(out[0, :, self.len1 :] == 0.0)
+        assert torch.all(out[1, :, : self.len2] != 0.0)
+        assert torch.all(out[1, :, self.len2 :] == 0.0)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_hifigan_decoder(self):
+        up_sample_rates = [4, 4, 2, 2]
+        up_sample_total = 64
+        lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)
+        out_len_1 = self.len1 * up_sample_total
+        out_len_2 = self.len2 * up_sample_total
+        out_len_max = self.max_len * up_sample_total
+
+        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
+        inputs = mask_sequence_tensor(tensor=inputs, lengths=lengths)
+
+        decoder = HiFiGANDecoder(
+            input_dim=self.in_channels, base_channels=self.filters, up_sample_rates=up_sample_rates
+        )
+        out, out_len = decoder(inputs=inputs, input_len=lengths)
+
+        assert out_len[0] == out_len_1
+        assert out_len[1] == out_len_2
+        assert out.shape == (self.batch_size, out_len_max)
+        assert torch.all(out[0, :out_len_1] != 0.0)
+        assert torch.all(out[0, out_len_1:] == 0.0)
+        assert torch.all(out[1, :out_len_2] != 0.0)
+        assert torch.all(out[1, out_len_2:] == 0.0)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_resnet_encoder(self):
+        lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)
+        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
+        inputs = mask_sequence_tensor(tensor=inputs, lengths=lengths)
+
+        res_net = ResNetEncoder(in_channels=self.in_channels, out_channels=self.out_channels)
+        out = res_net(inputs=inputs, input_len=lengths)
+
+        assert out.shape == (self.batch_size, self.out_channels, self.max_len)
+        assert torch.all(out[0, :, : self.len1] != 0.0)
+        assert torch.all(out[0, :, self.len1 :] == 0.0)
+        assert torch.all(out[1, :, : self.len2] != 0.0)
+        assert torch.all(out[1, :, self.len2 :] == 0.0)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_multiband_mel_encoder(self):
+        mel_dim = 10
+        win_length = 16
+        hop_length = 10
+        mel_bands = [(0, 4), (4, 7), (7, 10)]
+        max_len = 100
+        len1 = 40
+        len2 = 80
+        out_dim = len(mel_bands) * self.out_channels
+        lengths = torch.tensor([len1, len2], dtype=torch.int32)
+        out_len_1 = len1 // hop_length
+        out_len_2 = len2 // hop_length
+        out_len_max = max_len // hop_length
+
+        audio = torch.rand([self.batch_size, max_len])
+        audio = mask_sequence_tensor(tensor=audio, lengths=lengths)
+
+        mel_processor = MelSpectrogramProcessor(
+            mel_dim=mel_dim, sample_rate=100, win_length=win_length, hop_length=hop_length
+        )
+        encoder = MultiBandMelEncoder(
+            mel_bands=mel_bands,
+            mel_processor=mel_processor,
+            out_channels=self.out_channels,
+            hidden_channels=self.filters,
+            filters=self.filters,
+        )
+        out, out_len = encoder(audio=audio, audio_len=lengths)
+
+        assert out_len[0] == out_len_1
+        assert out_len[1] == out_len_2
+        assert out.shape == (self.batch_size, out_dim, out_len_max)
+        assert torch.all(out[0, :, :out_len_1] != 0.0)
+        assert torch.all(out[0, :, out_len_1:] == 0.0)
+        assert torch.all(out[1, :, :out_len_2] != 0.0)
+        assert torch.all(out[1, :, out_len_2:] == 0.0)
+
 
 class TestResidualVectorQuantizer:
     def setup_class(self):
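Note: below is a minimal, illustrative sketch (not part of the patch) of how the modules added in this diff compose into the mel codec described by mel_codec_44100.yaml. The quantizer's returned names are an assumption based on the FiniteScalarQuantizer conventions in this file; shapes are approximate.

    import torch

    from nemo.collections.tts.modules.audio_codec_modules import (
        GroupFiniteScalarQuantizer,
        HiFiGANDecoder,
        MelSpectrogramProcessor,
        MultiBandMelEncoder,
    )

    # Mirror the settings in mel_codec_44100.yaml on a toy batch.
    mel_processor = MelSpectrogramProcessor(sample_rate=44100, win_length=2048, hop_length=512, mel_dim=80)
    encoder = MultiBandMelEncoder(
        mel_bands=[(0, 10), (10, 20), (20, 30), (30, 40), (40, 50), (50, 60), (60, 70), (70, 80)],
        mel_processor=mel_processor,
        out_channels=4,  # the dimension of each codebook
        hidden_channels=128,
        filters=256,
    )
    quantizer = GroupFiniteScalarQuantizer(num_groups=8, num_levels_per_group=[8, 5, 5, 5])
    decoder = HiFiGANDecoder(input_dim=32, up_sample_rates=[8, 8, 4, 2], base_channels=1024)

    audio = torch.rand(2, 44100)  # two 1-second clips
    audio_len = torch.tensor([44100, 44100])

    # Encode: 8 mel bands * 4 channels -> [B, 32, T_frames], T_frames ~= 44100 / 512.
    encoded, encoded_len = encoder(audio=audio, audio_len=audio_len)
    # Quantize: return names assumed from the FSQ modules in this file (dequantized, indices).
    dequantized, indices = quantizer(inputs=encoded, input_len=encoded_len)
    # Decode back to audio: [B, T_frames * 512].
    audio_out, audio_out_len = decoder(inputs=dequantized, input_len=encoded_len)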