Skip to content

Commit

Permalink
[TTS] Create EnCodec training recipe (#6852)
Browse files Browse the repository at this point in the history
* [TTS] Create EnCodec training recipe

Signed-off-by: Ryan <[email protected]>

* [TTS] Update encodec recipe

Signed-off-by: Ryan <[email protected]>

* [TTS] Rename EnCodec to AudioCodec

Signed-off-by: Ryan <[email protected]>

* [TTS] Add EnCodec unit tests

Signed-off-by: Ryan <[email protected]>

* [TTS] Add copyright header to distributed.py

Signed-off-by: Ryan <[email protected]>

---------

Signed-off-by: Ryan <[email protected]>
Signed-off-by: jubick1337 <[email protected]>
  • Loading branch information
rlangman authored and jubick1337 committed Aug 8, 2023
1 parent 0b48771 commit eea78b2
Show file tree
Hide file tree
Showing 15 changed files with 2,128 additions and 1 deletion.
31 changes: 31 additions & 0 deletions examples/tts/audio_codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl

from nemo.collections.tts.models import AudioCodecModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec")
def main(cfg):
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
trainer.fit(model)


if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter
160 changes: 160 additions & 0 deletions examples/tts/conf/audio_codec/encodec.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# This config contains the default values for training 24khz EnCodec model
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: EnCodec

max_epochs: ???
# Adjust batch size based on GPU memory
batch_size: 16
# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
# If null, then weighted sampling is disabled.
weighted_sampling_steps_per_epoch: null

# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

log_ds_meta: ???
log_dir: ???

# Modify these values based on your sample rate
sample_rate: 24000
train_n_samples: 24000
down_sample_rates: [2, 4, 5, 8]
up_sample_rates: [8, 5, 4, 2]
# The number of samples per encoded audio frame. Should be the product of the down_sample_rates.
# For example 2 * 4 * 5 * 8 = 320.
samples_per_frame: 320

model:

max_epochs: ${max_epochs}
steps_per_epoch: ${weighted_sampling_steps_per_epoch}

sample_rate: ${sample_rate}
samples_per_frame: ${samples_per_frame}
time_domain_loss_scale: 0.1
# Probability of updating the discriminator during each training step
disc_update_prob: 0.67

# All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
mel_loss_resolutions: [
[32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]
]

train_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
sample_rate: ${sample_rate}
n_samples: ${train_n_samples}
min_duration: 1.01
max_duration: null
dataset_meta: ${train_ds_meta}

dataloader_params:
batch_size: ${batch_size}
drop_last: true
num_workers: 4

validation_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss
dataset_meta: ${val_ds_meta}

dataloader_params:
batch_size: 8
num_workers: 2

# Configures how audio samples are generated and saved during training.
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
log_epochs: [10, 50]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_quantized: true

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 15.0 # Only log the first 15 seconds of generated audio.
dataset_meta: ${log_ds_meta}

dataloader_params:
batch_size: 4
num_workers: 2

audio_encoder:
_target_: nemo.collections.tts.modules.audio_codec_modules.SEANetEncoder
down_sample_rates: ${down_sample_rates}

audio_decoder:
_target_: nemo.collections.tts.modules.audio_codec_modules.SEANetDecoder
up_sample_rates: ${up_sample_rates}

vector_quantizer:
_target_: nemo.collections.tts.modules.vector_quantization.ResidualVectorQuantizer
num_codebooks: 8

discriminator:
_target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT
resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]

# The original EnCodec uses hinged loss, but squared-GAN loss is more stable
# and reduces the need to tune the loss weights or use a gradient balancer.
generator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

discriminator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss

optim:
_target_: torch.optim.Adam
lr: 3e-4
betas: [0.5, 0.9]

sched:
name: ExponentialLR
gamma: 0.999

trainer:
num_nodes: 1
devices: 1
accelerator: gpu
strategy: ddp
precision: 32 # Vector quantization only works with 32-bit precision.
max_epochs: ${max_epochs}
accumulate_grad_batches: 1
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
log_every_n_steps: 100
check_val_every_n_epoch: 5
benchmark: false

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
create_wandb_logger: false
checkpoint_callback_params:
monitor: val_loss
resume_if_exists: false
resume_ignore_no_checkpoint: false
1 change: 1 addition & 0 deletions examples/tts/conf/hifigan/hifigan_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ model:
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 15.0
dataset_meta: ${log_ds_meta}

dataloader_params:
Expand Down
11 changes: 11 additions & 0 deletions nemo/collections/tts/data/vocoder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class VocoderDataset(Dataset):
will be ignored.
max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration'
will be ignored.
trunc_duration: Optional int, if provided audio will be truncated to at most 'trunc_duration' seconds.
num_audio_retries: Number of read attempts to make when sampling audio file, to avoid training failing
from sporadic IO errors.
"""
Expand All @@ -78,6 +79,7 @@ def __init__(
feature_processors: Optional[Dict[str, FeatureProcessor]] = None,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
trunc_duration: Optional[float] = None,
num_audio_retries: int = 5,
):
super().__init__()
Expand All @@ -88,6 +90,11 @@ def __init__(
self.num_audio_retries = num_audio_retries
self.load_precomputed_mel = False

if trunc_duration:
self.trunc_samples = int(trunc_duration * self.sample_rate)
else:
self.trunc_samples = None

if feature_processors:
logging.info(f"Found feature processors {feature_processors.keys()}")
self.feature_processors = list(feature_processors.values())
Expand Down Expand Up @@ -132,6 +139,10 @@ def _sample_audio(self, audio_filepath: Path) -> Tuple[torch.Tensor, torch.Tenso
else:
audio_segment = self._segment_audio(audio_filepath)
audio_array = audio_segment.samples

if self.trunc_samples:
audio_array = audio_array[: self.trunc_samples]

audio = torch.tensor(audio_array)
audio_len = torch.tensor(audio.shape[0])
return audio, audio_len
Expand Down
Loading

0 comments on commit eea78b2

Please sign in to comment.