Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Add callback for saving audio during FastPitch training #6665

Merged
merged 2 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion examples/tts/conf/fastpitch/fastpitch_22050.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@ feature_stats_path: null

train_ds_meta: ???
val_ds_meta: ???
log_ds_meta: ???

phoneme_dict_path: ???
heteronyms_path: ???

log_dir: ???
vocoder_type: ???
vocoder_name: null
vocoder_checkpoint_path: null

defaults:
- feature: feature_22050

Expand All @@ -27,6 +33,7 @@ model:

n_speakers: ${n_speakers}
n_mel_channels: ${feature.mel_feature.mel_dim}
min_token_duration: 1
max_token_duration: 75
symbols_embedding_dim: 384
pitch_embedding_kernel_size: 3
Expand Down Expand Up @@ -126,7 +133,42 @@ model:

dataloader_params:
batch_size: ${batch_size}
drop_last: false
num_workers: 2

log_config:
log_dir: ${log_dir}
log_epochs: [10, 50]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.FastPitchArtifactGenerator
log_spectrogram: true
log_alignment: true
audio_params:
_target_: nemo.collections.tts.parts.utils.callbacks.LogAudioParams
log_audio_gta: true
vocoder_type: ${vocoder_type}
vocoder_name: ${vocoder_name}
vocoder_checkpoint_path: ${vocoder_checkpoint_path}
rlangman marked this conversation as resolved.
Show resolved Hide resolved

dataset:
_target_: nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset
text_tokenizer: ${model.text_tokenizer}
sample_rate: ${feature.sample_rate}
speaker_path: ${speaker_path}
align_prior_config: ${model.align_prior_config}
featurizers: ${feature.featurizers}

feature_processors:
pitch: ${model.pitch_processor}
energy: ${model.energy_processor}

dataset_meta: ${log_ds_meta}

dataloader_params:
batch_size: 8
num_workers: 2

input_fft:
Expand Down
15 changes: 9 additions & 6 deletions nemo/collections/tts/data/text_to_speech_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def _process_dataset(

sample = DatasetSample(
manifest_entry=entry,
audio_dir=dataset.audio_dir,
feature_dir=dataset.feature_dir,
audio_dir=Path(dataset.audio_dir),
feature_dir=Path(dataset.feature_dir),
text=text,
speaker=speaker,
speaker_index=speaker_index,
Expand All @@ -208,12 +208,12 @@ def __getitem__(self, index):
data = self.data_samples[index]

audio_filepath = Path(data.manifest_entry["audio_filepath"])
audio_path, _ = get_abs_rel_paths(input_path=audio_filepath, base_path=data.audio_dir)
audio_filepath_abs, audio_filepath_rel = get_abs_rel_paths(input_path=audio_filepath, base_path=data.audio_dir)

audio, _ = librosa.load(audio_path, sr=self.sample_rate)
audio, _ = librosa.load(audio_filepath_abs, sr=self.sample_rate)
tokens = self.text_tokenizer(data.text)

example = {"audio": audio, "tokens": tokens}
example = {"audio_filepath": audio_filepath_rel, "audio": audio, "tokens": tokens}

if data.speaker is not None:
example["speaker"] = data.speaker
Expand Down Expand Up @@ -243,7 +243,7 @@ def __getitem__(self, index):
return example

def collate_fn(self, batch: List[dict]):

audio_filepath_list = []
audio_list = []
audio_len_list = []
token_list = []
Expand All @@ -252,6 +252,8 @@ def collate_fn(self, batch: List[dict]):
prior_list = []

for example in batch:
audio_filepath_list.append(example["audio_filepath"])

audio_tensor = torch.tensor(example["audio"], dtype=torch.float32)
audio_list.append(audio_tensor)
audio_len_list.append(audio_tensor.shape[0])
Expand All @@ -276,6 +278,7 @@ def collate_fn(self, batch: List[dict]):
batch_tokens = stack_tensors(token_list, max_lens=[token_max_len], pad_value=self.text_tokenizer.pad)

batch_dict = {
"audio_filepaths": audio_filepath_list,
"audio": batch_audio,
"audio_lens": batch_audio_len,
"text": batch_tokens,
Expand Down
36 changes: 34 additions & 2 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch
Expand All @@ -27,6 +28,7 @@
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import FastPitchModule
from nemo.collections.tts.parts.mixins import FastPitchAdapterModelMixin
from nemo.collections.tts.parts.utils.callbacks import LoggingCallback
from nemo.collections.tts.parts.utils.helpers import (
batch_from_ragged,
plot_alignment_to_numpy,
Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)

self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100)
self.log_images = cfg.get("log_images", False)
self.log_train_images = False

loss_scale = 0.1 if self.learn_alignment else 1.0
Expand Down Expand Up @@ -154,6 +157,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
min_token_duration = cfg.get("min_token_duration", 0)
use_log_energy = cfg.get("use_log_energy", True)
if n_speakers > 1 and "add" not in input_fft.cond_input.condition_types:
input_fft.cond_input.condition_types.append("add")
Expand All @@ -178,6 +182,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
cfg.pitch_embedding_kernel_size,
energy_embedding_kernel_size,
cfg.n_mel_channels,
min_token_duration,
cfg.max_token_duration,
use_log_energy,
)
Expand All @@ -190,6 +195,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if self.fastpitch.speaker_emb is not None:
self.export_config["num_speakers"] = cfg.n_speakers

self.log_config = cfg.get("log_config", None)

# Adapter modules setup (from FastPitchAdapterModelMixin)
self.setup_adapters()

Expand Down Expand Up @@ -462,7 +469,7 @@ def training_step(self, batch, batch_idx):
self.log("t_bin_loss", bin_loss)

# Log images to tensorboard
if self.log_train_images and isinstance(self.logger, TensorBoardLogger):
if self.log_images and self.log_train_images and isinstance(self.logger, TensorBoardLogger):
self.log_train_images = False

self.tb_logger.add_image(
Expand Down Expand Up @@ -571,7 +578,7 @@ def validation_epoch_end(self, outputs):

_, _, _, _, _, spec_target, spec_predict = outputs[0].values()

if isinstance(self.logger, TensorBoardLogger):
if self.log_images and isinstance(self.logger, TensorBoardLogger):
self.tb_logger.add_image(
"val_mel_target",
plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()),
Expand Down Expand Up @@ -658,6 +665,31 @@ def setup_test_data(self, cfg):
"""Omitted."""
pass

def configure_callbacks(self):
    """Build the artifact-logging callback described by ``self.log_config``.

    Returns:
        An empty list when ``self.log_config`` is unset or falsy; otherwise a
        one-element list holding a ``LoggingCallback``.

    Raises:
        ValueError: If ``log_config.dataset._target_`` is anything other than
            ``nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset``.
    """
    # Logging is opt-in: with no log_config there is nothing to register.
    if not self.log_config:
        return []

    sample_ds_class = self.log_config.dataset._target_
    if sample_ds_class != "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
        raise ValueError(f"Logging callback only supported for TextToSpeechDataset, got {sample_ds_class}")

    # The logging samples are loaded through the same helper as test data.
    data_loader = self._setup_test_dataloader(self.log_config)

    generators = instantiate(self.log_config.generators)
    # An empty or missing log_dir is forwarded as output_dir=None.
    log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None
    log_callback = LoggingCallback(
        generators=generators,
        data_loader=data_loader,
        log_epochs=self.log_config.log_epochs,
        epoch_frequency=self.log_config.epoch_frequency,
        output_dir=log_dir,
        loggers=self.trainer.loggers,
        log_tensorboard=self.log_config.log_tensorboard,
        log_wandb=self.log_config.log_wandb,
    )

    return [log_callback]

@classmethod
def list_available_models(cls) -> 'List[PretrainedModelInfo]':
"""
Expand Down
27 changes: 21 additions & 6 deletions nemo/collections/tts/modules/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def average_features(pitch, durs):
return pitch_avg


def log_to_duration(log_dur, min_dur, max_dur, mask):
    """Convert log-durations into clamped, masked durations.

    Computes ``exp(log_dur) - 1``, clamps the result to ``[min_dur, max_dur]``
    and multiplies it by ``mask`` with dimension 2 squeezed out.

    Args:
        log_dur: Tensor of log-durations.
        min_dur: Lower clamp bound.
        max_dur: Upper clamp bound.
        mask: Tensor with at least three dimensions; dimension 2 is squeezed
            before the multiply (presumably ``(batch, tokens, 1)`` with zeros
            at padded positions -- confirm against the callers).

    Returns:
        The clamped durations, zeroed wherever ``mask`` is zero.
    """
    dur = torch.clamp(torch.exp(log_dur) - 1.0, min_dur, max_dur)
    # Mask after clamping: a positive min_dur would otherwise give masked
    # positions a non-zero duration. The in-place multiply only touches the
    # tensor created by clamp above, never the caller's input.
    dur *= mask.squeeze(2)
    return dur


class ConvReLUNorm(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0, condition_dim=384, condition_types=[]):
super(ConvReLUNorm, self).__init__()
Expand Down Expand Up @@ -163,6 +169,7 @@ def __init__(
pitch_embedding_kernel_size: int,
energy_embedding_kernel_size: int,
n_mel_channels: int = 80,
min_token_duration: int = 0,
max_token_duration: int = 75,
use_log_energy: bool = True,
):
Expand All @@ -188,8 +195,8 @@ def __init__(
else:
self.speaker_emb = None

self.min_token_duration = min_token_duration
self.max_token_duration = max_token_duration
self.min_token_duration = 0

self.pitch_emb = torch.nn.Conv1d(
1,
Expand Down Expand Up @@ -294,7 +301,9 @@ def forward(

# Predict duration
log_durs_predicted = self.duration_predictor(enc_out, enc_mask, conditioning=spk_emb)
durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)
durs_predicted = log_to_duration(
log_dur=log_durs_predicted, min_dur=self.min_token_duration, max_dur=self.max_token_duration, mask=enc_mask
)

attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None
if self.learn_alignment and spec is not None:
Expand Down Expand Up @@ -398,8 +407,8 @@ def infer(

# Predict duration and pitch
log_durs_predicted = self.duration_predictor(enc_out, enc_mask, conditioning=spk_emb)
durs_predicted = torch.clamp(
torch.exp(log_durs_predicted) - 1.0, self.min_token_duration, self.max_token_duration
durs_predicted = log_to_duration(
log_dur=log_durs_predicted, min_dur=self.min_token_duration, max_dur=self.max_token_duration, mask=enc_mask
)
pitch_predicted = self.pitch_predictor(enc_out, enc_mask, conditioning=spk_emb) + pitch
pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
Expand Down Expand Up @@ -444,6 +453,7 @@ def __init__(
symbols_embedding_dim: int,
pitch_embedding_kernel_size: int,
n_mel_channels: int = 80,
min_token_duration: int = 0,
max_token_duration: int = 75,
):
super().__init__()
Expand All @@ -453,8 +463,8 @@ def __init__(
self.duration_predictor = duration_predictor
self.pitch_predictor = pitch_predictor

self.min_token_duration = min_token_duration
self.max_token_duration = max_token_duration
self.min_token_duration = 0

if self.pitch_predictor is not None:
self.pitch_emb = torch.nn.Conv1d(
Expand Down Expand Up @@ -497,7 +507,12 @@ def forward(self, *, enc_out=None, enc_mask=None, durs=None, pitch=None, pace=1.
log_durs_predicted, durs_predicted = None, None
if self.duration_predictor is not None:
log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)
durs_predicted = log_to_duration(
log_dur=log_durs_predicted,
min_dur=self.min_token_duration,
max_dur=self.max_token_duration,
mask=enc_mask,
)

# Predict pitch
pitch_predicted = None
Expand Down
Loading
Loading