
Commit

[TTS] Add callback for saving audio during FastPitch training (NVIDIA#6665)

* [TTS] Add callback for saving audio during FastPitch training

Signed-off-by: Ryan <[email protected]>

* [TTS] Allow NGC model name for vocoder

Signed-off-by: Ryan <[email protected]>

---------

Signed-off-by: Ryan <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
rlangman authored and hsiehjackson committed Jun 2, 2023
1 parent 8ed6a0a commit a6b856c
Showing 6 changed files with 517 additions and 15 deletions.
44 changes: 43 additions & 1 deletion examples/tts/conf/fastpitch/fastpitch_22050.yaml
@@ -14,10 +14,16 @@ feature_stats_path: null

train_ds_meta: ???
val_ds_meta: ???
log_ds_meta: ???

phoneme_dict_path: ???
heteronyms_path: ???

log_dir: ???
vocoder_type: ???
vocoder_name: null
vocoder_checkpoint_path: null

defaults:
- feature: feature_22050

@@ -27,6 +33,7 @@ model:

n_speakers: ${n_speakers}
n_mel_channels: ${feature.mel_feature.mel_dim}
min_token_duration: 1
max_token_duration: 75
symbols_embedding_dim: 384
pitch_embedding_kernel_size: 3
@@ -126,7 +133,42 @@ model:

dataloader_params:
batch_size: ${batch_size}
drop_last: false
num_workers: 2

log_config:
log_dir: ${log_dir}
log_epochs: [10, 50]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.FastPitchArtifactGenerator
log_spectrogram: true
log_alignment: true
audio_params:
_target_: nemo.collections.tts.parts.utils.callbacks.LogAudioParams
log_audio_gta: true
vocoder_type: ${vocoder_type}
vocoder_name: ${vocoder_name}
vocoder_checkpoint_path: ${vocoder_checkpoint_path}

dataset:
_target_: nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset
text_tokenizer: ${model.text_tokenizer}
sample_rate: ${feature.sample_rate}
speaker_path: ${speaker_path}
align_prior_config: ${model.align_prior_config}
featurizers: ${feature.featurizers}

feature_processors:
pitch: ${model.pitch_processor}
energy: ${model.energy_processor}

dataset_meta: ${log_ds_meta}

dataloader_params:
batch_size: 8
num_workers: 2

input_fft:
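
The new top-level vocoder keys (vocoder_type, vocoder_name, vocoder_checkpoint_path) let the audio-logging callback synthesize waveforms either from an NGC pretrained model name or from a local checkpoint, per the second commit message above. A minimal sketch of that choice, assuming a HiFi-GAN vocoder; the model name and checkpoint path are placeholders and this helper is illustrative, not part of the commit:

    from nemo.collections.tts.models import HifiGanModel

    def load_vocoder(vocoder_name=None, vocoder_checkpoint_path=None):
        # Prefer an explicit local checkpoint when one is given.
        if vocoder_checkpoint_path is not None:
            return HifiGanModel.load_from_checkpoint(vocoder_checkpoint_path)
        # Otherwise pull a pretrained model from NGC by name
        # (e.g. vocoder_name="tts_en_hifigan"; placeholder name).
        return HifiGanModel.from_pretrained(model_name=vocoder_name)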
15 changes: 9 additions & 6 deletions nemo/collections/tts/data/text_to_speech_dataset.py
@@ -190,8 +190,8 @@ def _process_dataset(

sample = DatasetSample(
manifest_entry=entry,
audio_dir=dataset.audio_dir,
feature_dir=dataset.feature_dir,
audio_dir=Path(dataset.audio_dir),
feature_dir=Path(dataset.feature_dir),
text=text,
speaker=speaker,
speaker_index=speaker_index,
@@ -208,12 +208,12 @@ def __getitem__(self, index):
data = self.data_samples[index]

audio_filepath = Path(data.manifest_entry["audio_filepath"])
audio_path, _ = get_abs_rel_paths(input_path=audio_filepath, base_path=data.audio_dir)
audio_filepath_abs, audio_filepath_rel = get_abs_rel_paths(input_path=audio_filepath, base_path=data.audio_dir)

audio, _ = librosa.load(audio_path, sr=self.sample_rate)
audio, _ = librosa.load(audio_filepath_abs, sr=self.sample_rate)
tokens = self.text_tokenizer(data.text)

example = {"audio": audio, "tokens": tokens}
example = {"audio_filepath": audio_filepath_rel, "audio": audio, "tokens": tokens}

if data.speaker is not None:
example["speaker"] = data.speaker
@@ -243,7 +243,7 @@ def __getitem__(self, index):
return example

def collate_fn(self, batch: List[dict]):

audio_filepath_list = []
audio_list = []
audio_len_list = []
token_list = []
@@ -252,6 +252,8 @@ def collate_fn(self, batch: List[dict]):
prior_list = []

for example in batch:
audio_filepath_list.append(example["audio_filepath"])

audio_tensor = torch.tensor(example["audio"], dtype=torch.float32)
audio_list.append(audio_tensor)
audio_len_list.append(audio_tensor.shape[0])
@@ -276,6 +278,7 @@ def collate_fn(self, batch: List[dict]):
batch_tokens = stack_tensors(token_list, max_lens=[token_max_len], pad_value=self.text_tokenizer.pad)

batch_dict = {
"audio_filepaths": audio_filepath_list,
"audio": batch_audio,
"audio_lens": batch_audio_len,
"text": batch_tokens,
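
__getitem__ now keeps both forms of the audio path: the absolute path is used to load audio, while the relative path travels through collate_fn as audio_filepaths so downstream logging can name saved artifacts. A small sketch of that split, assuming get_abs_rel_paths behaves as its use here suggests (inferred from this call site, not from its source); the directory and manifest entry below are hypothetical:

    from pathlib import Path

    audio_dir = Path("/data/LibriTTS")              # hypothetical dataset root
    manifest_filepath = Path("wavs/1234_0001.wav")  # hypothetical manifest audio_filepath

    # A relative manifest path resolves against audio_dir; the relative form is kept for logging.
    audio_filepath_abs = audio_dir / manifest_filepath   # /data/LibriTTS/wavs/1234_0001.wav
    audio_filepath_rel = manifest_filepath               # wavs/1234_0001.wav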
36 changes: 34 additions & 2 deletions nemo/collections/tts/models/fastpitch.py
@@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch
@@ -27,6 +28,7 @@
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import FastPitchModule
from nemo.collections.tts.parts.mixins import FastPitchAdapterModelMixin
from nemo.collections.tts.parts.utils.callbacks import LoggingCallback
from nemo.collections.tts.parts.utils.helpers import (
batch_from_ragged,
plot_alignment_to_numpy,
@@ -115,6 +117,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)

self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100)
self.log_images = cfg.get("log_images", False)
self.log_train_images = False

loss_scale = 0.1 if self.learn_alignment else 1.0
@@ -154,6 +157,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
min_token_duration = cfg.get("min_token_duration", 0)
use_log_energy = cfg.get("use_log_energy", True)
if n_speakers > 1 and "add" not in input_fft.cond_input.condition_types:
input_fft.cond_input.condition_types.append("add")
@@ -178,6 +182,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
cfg.pitch_embedding_kernel_size,
energy_embedding_kernel_size,
cfg.n_mel_channels,
min_token_duration,
cfg.max_token_duration,
use_log_energy,
)
@@ -190,6 +195,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if self.fastpitch.speaker_emb is not None:
self.export_config["num_speakers"] = cfg.n_speakers

self.log_config = cfg.get("log_config", None)

# Adapter modules setup (from FastPitchAdapterModelMixin)
self.setup_adapters()

@@ -462,7 +469,7 @@ def training_step(self, batch, batch_idx):
self.log("t_bin_loss", bin_loss)

# Log images to tensorboard
if self.log_train_images and isinstance(self.logger, TensorBoardLogger):
if self.log_images and self.log_train_images and isinstance(self.logger, TensorBoardLogger):
self.log_train_images = False

self.tb_logger.add_image(
@@ -571,7 +578,7 @@ def validation_epoch_end(self, outputs):

_, _, _, _, _, spec_target, spec_predict = outputs[0].values()

if isinstance(self.logger, TensorBoardLogger):
if self.log_images and isinstance(self.logger, TensorBoardLogger):
self.tb_logger.add_image(
"val_mel_target",
plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()),
@@ -658,6 +665,31 @@ def setup_test_data(self, cfg):
"""Omitted."""
pass

def configure_callbacks(self):
if not self.log_config:
return []

sample_ds_class = self.log_config.dataset._target_
if sample_ds_class != "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
raise ValueError(f"Logging callback only supported for TextToSpeechDataset, got {sample_ds_class}")

data_loader = self._setup_test_dataloader(self.log_config)

generators = instantiate(self.log_config.generators)
log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None
log_callback = LoggingCallback(
generators=generators,
data_loader=data_loader,
log_epochs=self.log_config.log_epochs,
epoch_frequency=self.log_config.epoch_frequency,
output_dir=log_dir,
loggers=self.trainer.loggers,
log_tensorboard=self.log_config.log_tensorboard,
log_wandb=self.log_config.log_wandb,
)

return [log_callback]

@classmethod
def list_available_models(cls) -> 'List[PretrainedModelInfo]':
"""
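
configure_callbacks above builds a logging data loader from log_config and attaches a LoggingCallback that renders artifacts on the configured epochs. A rough conceptual sketch of what such a callback does, using a plain PyTorch Lightning hook; the class name and the generate_artifacts method are illustrative placeholders, not the actual NeMo LoggingCallback API:

    from pathlib import Path
    from pytorch_lightning import Callback

    class EpochArtifactLogger(Callback):  # placeholder name, for illustration only
        def __init__(self, generators, data_loader, log_epochs, epoch_frequency, output_dir=None):
            self.generators = generators
            self.data_loader = data_loader
            self.log_epochs = set(log_epochs or [])
            self.epoch_frequency = epoch_frequency
            self.output_dir = Path(output_dir) if output_dir else None

        def on_train_epoch_end(self, trainer, pl_module):
            epoch = trainer.current_epoch
            if epoch not in self.log_epochs and epoch % self.epoch_frequency != 0:
                return
            for batch in self.data_loader:
                for generator in self.generators:
                    # Each generator (e.g. FastPitchArtifactGenerator) would render audio,
                    # spectrograms, or alignments for this batch and write them to
                    # output_dir / TensorBoard / W&B. The method name here is hypothetical.
                    generator.generate_artifacts(model=pl_module, batch=batch)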
27 changes: 21 additions & 6 deletions nemo/collections/tts/modules/fastpitch.py
@@ -80,6 +80,12 @@ def average_features(pitch, durs):
return pitch_avg


def log_to_duration(log_dur, min_dur, max_dur, mask):
dur = torch.clamp(torch.exp(log_dur) - 1.0, min_dur, max_dur)
dur *= mask.squeeze(2)
return dur


class ConvReLUNorm(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0, condition_dim=384, condition_types=[]):
super(ConvReLUNorm, self).__init__()
@@ -163,6 +169,7 @@ def __init__(
pitch_embedding_kernel_size: int,
energy_embedding_kernel_size: int,
n_mel_channels: int = 80,
min_token_duration: int = 0,
max_token_duration: int = 75,
use_log_energy: bool = True,
):
@@ -188,8 +195,8 @@ def __init__(
else:
self.speaker_emb = None

self.min_token_duration = min_token_duration
self.max_token_duration = max_token_duration
self.min_token_duration = 0

self.pitch_emb = torch.nn.Conv1d(
1,
@@ -294,7 +301,9 @@ def forward(

# Predict duration
log_durs_predicted = self.duration_predictor(enc_out, enc_mask, conditioning=spk_emb)
durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)
durs_predicted = log_to_duration(
log_dur=log_durs_predicted, min_dur=self.min_token_duration, max_dur=self.max_token_duration, mask=enc_mask
)

attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None
if self.learn_alignment and spec is not None:
@@ -398,8 +407,8 @@ def infer(

# Predict duration and pitch
log_durs_predicted = self.duration_predictor(enc_out, enc_mask, conditioning=spk_emb)
durs_predicted = torch.clamp(
torch.exp(log_durs_predicted) - 1.0, self.min_token_duration, self.max_token_duration
durs_predicted = log_to_duration(
log_dur=log_durs_predicted, min_dur=self.min_token_duration, max_dur=self.max_token_duration, mask=enc_mask
)
pitch_predicted = self.pitch_predictor(enc_out, enc_mask, conditioning=spk_emb) + pitch
pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
@@ -444,6 +453,7 @@ def __init__(
symbols_embedding_dim: int,
pitch_embedding_kernel_size: int,
n_mel_channels: int = 80,
min_token_duration: int = 0,
max_token_duration: int = 75,
):
super().__init__()
@@ -453,8 +463,8 @@ def __init__(
self.duration_predictor = duration_predictor
self.pitch_predictor = pitch_predictor

self.min_token_duration = min_token_duration
self.max_token_duration = max_token_duration
self.min_token_duration = 0

if self.pitch_predictor is not None:
self.pitch_emb = torch.nn.Conv1d(
@@ -497,7 +507,12 @@ def forward(self, *, enc_out=None, enc_mask=None, durs=None, pitch=None, pace=1.
log_durs_predicted, durs_predicted = None, None
if self.duration_predictor is not None:
log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)
durs_predicted = log_to_duration(
log_dur=log_durs_predicted,
min_dur=self.min_token_duration,
max_dur=self.max_token_duration,
mask=enc_mask,
)

# Predict pitch
pitch_predicted = None
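
The new log_to_duration helper replaces the two inline torch.clamp calls and, together with the new min_token_duration setting (1 in fastpitch_22050.yaml, default 0), keeps predicted durations from collapsing to zero frames. A worked example with hypothetical predictor outputs:

    import torch

    log_dur = torch.tensor([[0.0, 1.0, 5.0]])   # hypothetical duration-predictor output
    mask = torch.ones(1, 3, 1)                  # all three tokens are valid

    dur = torch.clamp(torch.exp(log_dur) - 1.0, 1, 75)  # min_dur=1, max_dur=75
    dur = dur * mask.squeeze(2)
    print(dur)  # tensor([[ 1.0000,  1.7183, 75.0000]])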
