Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Create EnCodec training recipe #6852

Merged
merged 5 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/tts/audio_codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl

from nemo.collections.tts.models import AudioCodecModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec")
def main(cfg):
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
trainer.fit(model)


if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter
160 changes: 160 additions & 0 deletions examples/tts/conf/audio_codec/encodec.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# This config contains the default values for training 24khz EnCodec model
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: EnCodec

max_epochs: ???
# Adjust batch size based on GPU memory
batch_size: 16
# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
# If null, then weighted sampling is disabled.
weighted_sampling_steps_per_epoch: null

# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

log_ds_meta: ???
log_dir: ???

# Modify these values based on your sample rate
sample_rate: 24000
train_n_samples: 24000
down_sample_rates: [2, 4, 5, 8]
up_sample_rates: [8, 5, 4, 2]
# The number of samples per encoded audio frame. Should be the product of the down_sample_rates.
# For example 2 * 4 * 5 * 8 = 320.
samples_per_frame: 320

model:

max_epochs: ${max_epochs}
steps_per_epoch: ${weighted_sampling_steps_per_epoch}

sample_rate: ${sample_rate}
samples_per_frame: ${samples_per_frame}
time_domain_loss_scale: 0.1
# Probability of updating the discriminator during each training step
disc_update_prob: 0.67

# All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
mel_loss_resolutions: [
[32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]
]

train_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
sample_rate: ${sample_rate}
n_samples: ${train_n_samples}
min_duration: 1.01
max_duration: null
dataset_meta: ${train_ds_meta}

dataloader_params:
batch_size: ${batch_size}
drop_last: true
num_workers: 4

validation_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss
dataset_meta: ${val_ds_meta}

dataloader_params:
batch_size: 8
num_workers: 2

# Configures how audio samples are generated and saved during training.
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
log_epochs: [10, 50]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_quantized: true

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 15.0 # Only log the first 15 seconds of generated audio.
dataset_meta: ${log_ds_meta}

dataloader_params:
batch_size: 4
num_workers: 2

audio_encoder:
_target_: nemo.collections.tts.modules.audio_codec_modules.SEANetEncoder
down_sample_rates: ${down_sample_rates}

audio_decoder:
_target_: nemo.collections.tts.modules.audio_codec_modules.SEANetDecoder
up_sample_rates: ${up_sample_rates}

vector_quantizer:
_target_: nemo.collections.tts.modules.vector_quantization.ResidualVectorQuantizer
num_codebooks: 8

discriminator:
_target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT
resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]

# The original EnCodec uses hinged loss, but squared-GAN loss is more stable
# and reduces the need to tune the loss weights or use a gradient balancer.
generator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

discriminator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss

optim:
_target_: torch.optim.Adam
lr: 3e-4
betas: [0.5, 0.9]

sched:
name: ExponentialLR
gamma: 0.999

trainer:
num_nodes: 1
devices: 1
accelerator: gpu
strategy: ddp
precision: 32 # Vector quantization only works with 32-bit precision.
max_epochs: ${max_epochs}
accumulate_grad_batches: 1
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
log_every_n_steps: 100
check_val_every_n_epoch: 5
benchmark: false

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
create_wandb_logger: false
checkpoint_callback_params:
monitor: val_loss
resume_if_exists: false
resume_ignore_no_checkpoint: false
1 change: 1 addition & 0 deletions examples/tts/conf/hifigan/hifigan_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ model:
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 15.0
dataset_meta: ${log_ds_meta}

dataloader_params:
Expand Down
11 changes: 11 additions & 0 deletions nemo/collections/tts/data/vocoder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class VocoderDataset(Dataset):
will be ignored.
max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration'
will be ignored.
trunc_duration: Optional int, if provided audio will be truncated to at most 'trunc_duration' seconds.
num_audio_retries: Number of read attempts to make when sampling audio file, to avoid training failing
from sporadic IO errors.
"""
Expand All @@ -78,6 +79,7 @@ def __init__(
feature_processors: Optional[Dict[str, FeatureProcessor]] = None,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
trunc_duration: Optional[float] = None,
num_audio_retries: int = 5,
):
super().__init__()
Expand All @@ -88,6 +90,11 @@ def __init__(
self.num_audio_retries = num_audio_retries
self.load_precomputed_mel = False

if trunc_duration:
self.trunc_samples = int(trunc_duration * self.sample_rate)
else:
self.trunc_samples = None

if feature_processors:
logging.info(f"Found feature processors {feature_processors.keys()}")
self.feature_processors = list(feature_processors.values())
Expand Down Expand Up @@ -132,6 +139,10 @@ def _sample_audio(self, audio_filepath: Path) -> Tuple[torch.Tensor, torch.Tenso
else:
audio_segment = self._segment_audio(audio_filepath)
audio_array = audio_segment.samples

if self.trunc_samples:
audio_array = audio_array[: self.trunc_samples]

audio = torch.tensor(audio_array)
audio_len = torch.tensor(audio.shape[0])
return audio, audio_len
Expand Down
Loading
Loading