Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Add STFT and SI-SDR loss to audio codec recipe #7468

Merged
merged 5 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions examples/tts/conf/audio_codec/audio_codec_24000.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# This config contains the default values for training 24khz audio codec model
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: EnCodec

max_epochs: ???
# Adjust batch size based on GPU memory
batch_size: 16
# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
# If null, then weighted sampling is disabled.
weighted_sampling_steps_per_epoch: null

# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

log_ds_meta: ???
log_dir: ???

# Modify these values based on your sample rate
sample_rate: 24000
train_n_samples: 24000
down_sample_rates: [2, 4, 5, 8]
up_sample_rates: [8, 5, 4, 2]
redoctopus marked this conversation as resolved.
Show resolved Hide resolved
# The number of samples per encoded audio frame. Should be the product of the down_sample_rates.
# For example 2 * 4 * 5 * 8 = 320.
samples_per_frame: 320

model:

max_epochs: ${max_epochs}
steps_per_epoch: ${weighted_sampling_steps_per_epoch}

sample_rate: ${sample_rate}
samples_per_frame: ${samples_per_frame}

mel_loss_l1_scale: 15.0
mel_loss_l2_scale: 0.0
stft_loss_scale: 15.0
time_domain_loss_scale: 0.0

# Probability of updating the discriminator during each training step
# For example, update the discriminator 2/3 times (2 updates for every 3 batches)
disc_updates_per_period: 2
disc_update_period: 3

# All resolutions for reconstruction loss, ordered [num_fft, hop_length, window_length]
loss_resolutions: [
[32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]
]
mel_loss_dims: [5, 10, 20, 40, 80, 160, 320]
mel_loss_log_guard: 1.0
stft_loss_log_guard: 1.0

train_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
sample_rate: ${sample_rate}
n_samples: ${train_n_samples}
min_duration: 1.01
max_duration: null
dataset_meta: ${train_ds_meta}

dataloader_params:
batch_size: ${batch_size}
drop_last: true
num_workers: 4

validation_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss
dataset_meta: ${val_ds_meta}

dataloader_params:
batch_size: 8
num_workers: 2

# Configures how audio samples are generated and saved during training.
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
log_epochs: [10, 50]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_dequantized: true

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 15.0 # Only log the first 15 seconds of generated audio.
dataset_meta: ${log_ds_meta}

dataloader_params:
batch_size: 4
num_workers: 2

audio_encoder:
_target_: nemo.collections.tts.modules.encodec_modules.HifiGanEncoder
down_sample_rates: ${down_sample_rates}

audio_decoder:
_target_: nemo.collections.tts.modules.encodec_modules.SEANetDecoder
up_sample_rates: ${up_sample_rates}

vector_quantizer:
_target_: nemo.collections.tts.modules.encodec_modules.ResidualVectorQuantizer
num_codebooks: 8

discriminator:
_target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT
resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]

# The original EnCodec uses hinged loss, but squared-GAN loss is more stable
# and reduces the need to tune the loss weights or use a gradient balancer.
generator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

discriminator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss

optim:
_target_: torch.optim.Adam
lr: 3e-4
betas: [0.5, 0.9]

sched:
name: ExponentialLR
gamma: 0.998

trainer:
num_nodes: 1
devices: 1
accelerator: gpu
strategy: ddp_find_unused_parameters_true
precision: 32 # Vector quantization only works with 32-bit precision.
max_epochs: ${max_epochs}
accumulate_grad_batches: 1
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
log_every_n_steps: 100
check_val_every_n_epoch: 10
benchmark: false

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
create_wandb_logger: false
checkpoint_callback_params:
monitor: val_loss
resume_if_exists: false
resume_ignore_no_checkpoint: false
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,23 @@ model:
sample_rate: ${sample_rate}
samples_per_frame: ${samples_per_frame}

mel_loss_scale: 5.0
mel_loss_l1_scale: 1.0
mel_loss_l2_scale: 1.0
stft_loss_scale: 0.0
time_domain_loss_scale: 0.1

# Probability of updating the discriminator during each training step
# For example, update the discriminator 2/3 times (2 updates for every 3 batches)
disc_updates_per_period: 2
disc_update_period: 3

# All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
mel_loss_resolutions: [
# All resolutions for reconstruction loss, ordered [num_fft, hop_length, window_length]
loss_resolutions: [
[32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]
]
mel_loss_dims: [64, 64, 64, 64, 64, 64, 64]
mel_loss_log_guard: 1E-5
stft_loss_log_guard: 1.0

train_ds:
dataset:
Expand Down
Loading