From a6b856c1cc7c48ba52c013713e361218d197a684 Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Thu, 18 May 2023 09:39:56 -0700 Subject: [PATCH] [TTS] Add callback for saving audio during FastPitch training (#6665) * [TTS] Add callback for saving audio during FastPitch training Signed-off-by: Ryan * [TTS] Allow NGC model name for vocoder Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: hsiehjackson --- .../tts/conf/fastpitch/fastpitch_22050.yaml | 44 +- .../tts/data/text_to_speech_dataset.py | 15 +- nemo/collections/tts/models/fastpitch.py | 36 +- nemo/collections/tts/modules/fastpitch.py | 27 +- nemo/collections/tts/parts/utils/callbacks.py | 393 ++++++++++++++++++ nemo/collections/tts/parts/utils/helpers.py | 17 + 6 files changed, 517 insertions(+), 15 deletions(-) create mode 100644 nemo/collections/tts/parts/utils/callbacks.py diff --git a/examples/tts/conf/fastpitch/fastpitch_22050.yaml b/examples/tts/conf/fastpitch/fastpitch_22050.yaml index 016e157ce39f..4022e8e91c97 100644 --- a/examples/tts/conf/fastpitch/fastpitch_22050.yaml +++ b/examples/tts/conf/fastpitch/fastpitch_22050.yaml @@ -14,10 +14,16 @@ feature_stats_path: null train_ds_meta: ??? val_ds_meta: ??? +log_ds_meta: ??? phoneme_dict_path: ??? heteronyms_path: ??? +log_dir: ??? +vocoder_type: ??? +vocoder_name: null +vocoder_checkpoint_path: null + defaults: - feature: feature_22050 @@ -27,6 +33,7 @@ model: n_speakers: ${n_speakers} n_mel_channels: ${feature.mel_feature.mel_dim} + min_token_duration: 1 max_token_duration: 75 symbols_embedding_dim: 384 pitch_embedding_kernel_size: 3 @@ -126,7 +133,42 @@ model: dataloader_params: batch_size: ${batch_size} - drop_last: false + num_workers: 2 + + log_config: + log_dir: ${log_dir} + log_epochs: [10, 50] + epoch_frequency: 100 + log_tensorboard: false + log_wandb: false + + generators: + - _target_: nemo.collections.tts.parts.utils.callbacks.FastPitchArtifactGenerator + log_spectrogram: true + log_alignment: true + audio_params: + _target_: nemo.collections.tts.parts.utils.callbacks.LogAudioParams + log_audio_gta: true + vocoder_type: ${vocoder_type} + vocoder_name: ${vocoder_name} + vocoder_checkpoint_path: ${vocoder_checkpoint_path} + + dataset: + _target_: nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset + text_tokenizer: ${model.text_tokenizer} + sample_rate: ${feature.sample_rate} + speaker_path: ${speaker_path} + align_prior_config: ${model.align_prior_config} + featurizers: ${feature.featurizers} + + feature_processors: + pitch: ${model.pitch_processor} + energy: ${model.energy_processor} + + dataset_meta: ${log_ds_meta} + + dataloader_params: + batch_size: 8 num_workers: 2 input_fft: diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index f6230fa3493a..47868d41d1ec 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -190,8 +190,8 @@ def _process_dataset( sample = DatasetSample( manifest_entry=entry, - audio_dir=dataset.audio_dir, - feature_dir=dataset.feature_dir, + audio_dir=Path(dataset.audio_dir), + feature_dir=Path(dataset.feature_dir), text=text, speaker=speaker, speaker_index=speaker_index, @@ -208,12 +208,12 @@ def __getitem__(self, index): data = self.data_samples[index] audio_filepath = Path(data.manifest_entry["audio_filepath"]) - audio_path, _ = get_abs_rel_paths(input_path=audio_filepath, base_path=data.audio_dir) + audio_filepath_abs, audio_filepath_rel = 
get_abs_rel_paths(input_path=audio_filepath, base_path=data.audio_dir) - audio, _ = librosa.load(audio_path, sr=self.sample_rate) + audio, _ = librosa.load(audio_filepath_abs, sr=self.sample_rate) tokens = self.text_tokenizer(data.text) - example = {"audio": audio, "tokens": tokens} + example = {"audio_filepath": audio_filepath_rel, "audio": audio, "tokens": tokens} if data.speaker is not None: example["speaker"] = data.speaker @@ -243,7 +243,7 @@ def __getitem__(self, index): return example def collate_fn(self, batch: List[dict]): - + audio_filepath_list = [] audio_list = [] audio_len_list = [] token_list = [] @@ -252,6 +252,8 @@ def collate_fn(self, batch: List[dict]): prior_list = [] for example in batch: + audio_filepath_list.append(example["audio_filepath"]) + audio_tensor = torch.tensor(example["audio"], dtype=torch.float32) audio_list.append(audio_tensor) audio_len_list.append(audio_tensor.shape[0]) @@ -276,6 +278,7 @@ def collate_fn(self, batch: List[dict]): batch_tokens = stack_tensors(token_list, max_lens=[token_max_len], pad_value=self.text_tokenizer.pad) batch_dict = { + "audio_filepaths": audio_filepath_list, "audio": batch_audio, "audio_lens": batch_audio_len, "text": batch_tokens, diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 281a7c2891b3..3939c9453911 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -13,6 +13,7 @@ # limitations under the License. import contextlib from dataclasses import dataclass +from pathlib import Path from typing import List, Optional import torch @@ -27,6 +28,7 @@ from nemo.collections.tts.models.base import SpectrogramGenerator from nemo.collections.tts.modules.fastpitch import FastPitchModule from nemo.collections.tts.parts.mixins import FastPitchAdapterModelMixin +from nemo.collections.tts.parts.utils.callbacks import LoggingCallback from nemo.collections.tts.parts.utils.helpers import ( batch_from_ragged, plot_alignment_to_numpy, @@ -115,6 +117,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100) + self.log_images = cfg.get("log_images", False) self.log_train_images = False loss_scale = 0.1 if self.learn_alignment else 1.0 @@ -154,6 +157,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False) speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False) speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False) + min_token_duration = cfg.get("min_token_duration", 0) use_log_energy = cfg.get("use_log_energy", True) if n_speakers > 1 and "add" not in input_fft.cond_input.condition_types: input_fft.cond_input.condition_types.append("add") @@ -178,6 +182,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): cfg.pitch_embedding_kernel_size, energy_embedding_kernel_size, cfg.n_mel_channels, + min_token_duration, cfg.max_token_duration, use_log_energy, ) @@ -190,6 +195,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if self.fastpitch.speaker_emb is not None: self.export_config["num_speakers"] = cfg.n_speakers + self.log_config = cfg.get("log_config", None) + # Adapter modules setup (from FastPitchAdapterModelMixin) self.setup_adapters() @@ -462,7 +469,7 @@ def training_step(self, batch, batch_idx): self.log("t_bin_loss", bin_loss) # Log images to tensorboard - 
if self.log_train_images and isinstance(self.logger, TensorBoardLogger): + if self.log_images and self.log_train_images and isinstance(self.logger, TensorBoardLogger): self.log_train_images = False self.tb_logger.add_image( @@ -571,7 +578,7 @@ def validation_epoch_end(self, outputs): _, _, _, _, _, spec_target, spec_predict = outputs[0].values() - if isinstance(self.logger, TensorBoardLogger): + if self.log_images and isinstance(self.logger, TensorBoardLogger): self.tb_logger.add_image( "val_mel_target", plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()), @@ -658,6 +665,31 @@ def setup_test_data(self, cfg): """Omitted.""" pass + def configure_callbacks(self): + if not self.log_config: + return [] + + sample_ds_class = self.log_config.dataset._target_ + if sample_ds_class != "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset": + raise ValueError(f"Logging callback only supported for TextToSpeechDataset, got {sample_ds_class}") + + data_loader = self._setup_test_dataloader(self.log_config) + + generators = instantiate(self.log_config.generators) + log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None + log_callback = LoggingCallback( + generators=generators, + data_loader=data_loader, + log_epochs=self.log_config.log_epochs, + epoch_frequency=self.log_config.epoch_frequency, + output_dir=log_dir, + loggers=self.trainer.loggers, + log_tensorboard=self.log_config.log_tensorboard, + log_wandb=self.log_config.log_wandb, + ) + + return [log_callback] + @classmethod def list_available_models(cls) -> 'List[PretrainedModelInfo]': """ diff --git a/nemo/collections/tts/modules/fastpitch.py b/nemo/collections/tts/modules/fastpitch.py index b26aafa72e32..f7601302d81e 100644 --- a/nemo/collections/tts/modules/fastpitch.py +++ b/nemo/collections/tts/modules/fastpitch.py @@ -80,6 +80,12 @@ def average_features(pitch, durs): return pitch_avg +def log_to_duration(log_dur, min_dur, max_dur, mask): + dur = torch.clamp(torch.exp(log_dur) - 1.0, min_dur, max_dur) + dur *= mask.squeeze(2) + return dur + + class ConvReLUNorm(torch.nn.Module, adapter_mixins.AdapterModuleMixin): def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0, condition_dim=384, condition_types=[]): super(ConvReLUNorm, self).__init__() @@ -163,6 +169,7 @@ def __init__( pitch_embedding_kernel_size: int, energy_embedding_kernel_size: int, n_mel_channels: int = 80, + min_token_duration: int = 0, max_token_duration: int = 75, use_log_energy: bool = True, ): @@ -188,8 +195,8 @@ def __init__( else: self.speaker_emb = None + self.min_token_duration = min_token_duration self.max_token_duration = max_token_duration - self.min_token_duration = 0 self.pitch_emb = torch.nn.Conv1d( 1, @@ -294,7 +301,9 @@ def forward( # Predict duration log_durs_predicted = self.duration_predictor(enc_out, enc_mask, conditioning=spk_emb) - durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration) + durs_predicted = log_to_duration( + log_dur=log_durs_predicted, min_dur=self.min_token_duration, max_dur=self.max_token_duration, mask=enc_mask + ) attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None if self.learn_alignment and spec is not None: @@ -398,8 +407,8 @@ def infer( # Predict duration and pitch log_durs_predicted = self.duration_predictor(enc_out, enc_mask, conditioning=spk_emb) - durs_predicted = torch.clamp( - torch.exp(log_durs_predicted) - 1.0, self.min_token_duration, self.max_token_duration + durs_predicted = log_to_duration( + 
log_dur=log_durs_predicted, min_dur=self.min_token_duration, max_dur=self.max_token_duration, mask=enc_mask ) pitch_predicted = self.pitch_predictor(enc_out, enc_mask, conditioning=spk_emb) + pitch pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1)) @@ -444,6 +453,7 @@ def __init__( symbols_embedding_dim: int, pitch_embedding_kernel_size: int, n_mel_channels: int = 80, + min_token_duration: int = 0, max_token_duration: int = 75, ): super().__init__() @@ -453,8 +463,8 @@ def __init__( self.duration_predictor = duration_predictor self.pitch_predictor = pitch_predictor + self.min_token_duration = min_token_duration self.max_token_duration = max_token_duration - self.min_token_duration = 0 if self.pitch_predictor is not None: self.pitch_emb = torch.nn.Conv1d( @@ -497,7 +507,12 @@ def forward(self, *, enc_out=None, enc_mask=None, durs=None, pitch=None, pace=1. log_durs_predicted, durs_predicted = None, None if self.duration_predictor is not None: log_durs_predicted = self.duration_predictor(enc_out, enc_mask) - durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration) + durs_predicted = log_to_duration( + log_dur=log_durs_predicted, + min_dur=self.min_token_duration, + max_dur=self.max_token_duration, + mask=enc_mask, + ) # Predict pitch pitch_predicted = None diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py new file mode 100644 index 000000000000..0f8bd0fa4177 --- /dev/null +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -0,0 +1,393 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Type + +import librosa +import numpy as np +import soundfile as sf +import torch +from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.loggers.wandb import WandbLogger + +from nemo.collections.tts.parts.utils.helpers import create_plot +from nemo.utils.decorators import experimental + +HAVE_WANDB = True +try: + import wandb +except ModuleNotFoundError: + HAVE_WANDB = False + + +def _get_logger(loggers: List[Logger], logger_type: Type[Logger]): + for logger in loggers: + if isinstance(logger, logger_type): + if hasattr(logger, "experiment"): + return logger.experiment + else: + return logger + raise ValueError(f"Could not find {logger_type} logger in {loggers}.") + + +def _load_vocoder(model_name: Optional[str], checkpoint_path: Optional[str], type: str): + assert (model_name is None) != ( + checkpoint_path is None + ), f"Must provide exactly one of vocoder model_name or checkpoint: ({model_name}, {checkpoint_path})" + + checkpoint_path = str(checkpoint_path) + if type == "hifigan": + from nemo.collections.tts.models import HifiGanModel + + model_type = HifiGanModel + elif type == "univnet": + from nemo.collections.tts.models import UnivNetModel + + model_type = UnivNetModel + else: + raise ValueError(f"Unknown vocoder type '{type}'") + + if model_name is not None: + vocoder = model_type.from_pretrained(model_name).eval() + else: + vocoder = model_type.load_from_checkpoint(checkpoint_path).eval() + + return vocoder + + +@dataclass +class AudioArtifact: + id: str + data: np.ndarray + sample_rate: int + filename: str + + +@dataclass +class ImageArtifact: + id: str + data: np.ndarray + filename: str + x_axis: str + y_axis: str + + +@dataclass +class LogAudioParams: + vocoder_type: str + vocoder_name: str + vocoder_checkpoint_path: str + log_audio_gta: bool = False + + +def create_id(filepath: Path) -> str: + path_prefix = str(filepath.with_suffix("")) + file_id = path_prefix.replace(os.sep, "_") + return file_id + + +class ArtifactGenerator(ABC): + @abstractmethod + def generate_artifacts( + self, model: LightningModule, batch_dict: Dict + ) -> Tuple[List[AudioArtifact], List[ImageArtifact]]: + """ + Create artifacts for the input model and test batch. + + Args: + model: Model instance being trained to use for inference. + batch_dict: Test batch to generate artifacts for. + + Returns: + List of audio and image artifacts to log. + """ + + +@experimental +class LoggingCallback(Callback): + """ + Callback which can log artifacts (eg. model predictions, graphs) to local disk, Tensorboard, and/or WandB. + + Args: + generators: List of generators to create and log artifacts from. + data_loader: Data to log artifacts for. + log_epochs: Optional list of specific training epoch numbers to log artifacts for. + epoch_frequency: Frequency with which to log + output_dir: Optional local directory. If provided, artifacts will be saved in output_dir. + loggers: Optional list of loggers to use if logging to tensorboard or wandb. + log_tensorboard: Whether to log artifacts to tensorboard. + log_wandb: Whether to log artifacts to WandB. 
+ """ + + def __init__( + self, + generators: List[ArtifactGenerator], + data_loader: torch.utils.data.DataLoader, + log_epochs: Optional[List[int]] = None, + epoch_frequency: int = 1, + output_dir: Optional[Path] = None, + loggers: Optional[List[Logger]] = None, + log_tensorboard: bool = False, + log_wandb: bool = False, + ): + self.generators = generators + self.data_loader = data_loader + self.log_epochs = log_epochs if log_epochs else [] + self.epoch_frequency = epoch_frequency + self.output_dir = Path(output_dir) if output_dir else None + self.loggers = loggers if loggers else [] + self.log_tensorboard = log_tensorboard + self.log_wandb = log_wandb + + if log_tensorboard: + self.tensorboard_logger = _get_logger(self.loggers, TensorBoardLogger) + else: + self.tensorboard_logger = None + + if log_wandb: + if not HAVE_WANDB: + raise ValueError("Wandb not installed.") + self.wandb_logger = _get_logger(self.loggers, WandbLogger) + else: + self.wandb_logger = None + + def _log_audio(self, audio: AudioArtifact, log_dir: Path, step: int): + if log_dir: + filepath = log_dir / audio.filename + sf.write(file=filepath, data=audio.data, samplerate=audio.sample_rate) + + if self.tensorboard_logger: + self.tensorboard_logger.add_audio( + tag=audio.id, snd_tensor=audio.data, global_step=step, sample_rate=audio.sample_rate, + ) + + if self.wandb_logger: + wandb_audio = (wandb.Audio(audio.data, sample_rate=audio.sample_rate, caption=audio.id),) + self.wandb_logger.log({audio.id: wandb_audio}) + + def _log_image(self, image: ImageArtifact, log_dir: Path, step: int): + if log_dir: + filepath = log_dir / image.filename + else: + filepath = None + + image_plot = create_plot(output_filepath=filepath, data=image.data, x_axis=image.x_axis, y_axis=image.y_axis) + + if self.tensorboard_logger: + self.tensorboard_logger.add_image( + tag=image.id, img_tensor=image_plot, global_step=step, dataformats="HWC", + ) + + if self.wandb_logger: + wandb_image = (wandb.Image(image_plot, caption=image.id),) + self.wandb_logger.log({image.id: wandb_image}) + + def on_train_epoch_end(self, trainer: Trainer, model: LightningModule): + epoch = 1 + model.current_epoch + if (epoch not in self.log_epochs) and (epoch % self.epoch_frequency != 0): + return + + if self.output_dir: + log_dir = self.output_dir / f"epoch_{epoch}" + log_dir.mkdir(parents=True, exist_ok=True) + else: + log_dir = None + + audio_list = [] + image_list = [] + for batch_dict in self.data_loader: + for key, value in batch_dict.items(): + if isinstance(value, torch.Tensor): + batch_dict[key] = value.to(model.device) + + for generator in self.generators: + audio, images = generator.generate_artifacts(model=model, batch_dict=batch_dict) + audio_list += audio + image_list += images + + for audio in audio_list: + self._log_audio(audio=audio, log_dir=log_dir, step=model.global_step) + + for image in image_list: + self._log_image(image=image, log_dir=log_dir, step=model.global_step) + + +class FastPitchArtifactGenerator(ArtifactGenerator): + """ + Generator for logging FastPitch model outputs. + + Args: + log_spectrogram: Whether to log predicted spectrograms. + log_alignment: Whether to log alignment graphs. + audio_params: Optional parameters for saving predicted audio. + Requires a vocoder model checkpoint for generating audio from predicted spectrograms. 
+ """ + + def __init__( + self, + log_spectrogram: bool = False, + log_alignment: bool = False, + audio_params: Optional[LogAudioParams] = None, + ): + self.log_spectrogram = log_spectrogram + self.log_alignment = log_alignment + + if not audio_params: + self.log_audio = False + self.log_audio_gta = False + self.vocoder = None + else: + self.log_audio = True + self.log_audio_gta = audio_params.log_audio_gta + self.vocoder = _load_vocoder( + model_name=audio_params.vocoder_name, + checkpoint_path=audio_params.vocoder_checkpoint_path, + type=audio_params.vocoder_type, + ) + + def _generate_audio(self, mels, mels_len, hop_length): + voc_input = mels.to(self.vocoder.device) + with torch.no_grad(): + audio_pred = self.vocoder.convert_spectrogram_to_audio(spec=voc_input) + + mels_len_array = mels_len.cpu().numpy() + audio_pred_lens = librosa.core.frames_to_samples(mels_len_array, hop_length=hop_length) + return audio_pred, audio_pred_lens + + def _generate_predictions(self, model: LightningModule, audio_ids: List[str], batch_dict: Dict): + audio_artifacts = [] + image_artifacts = [] + + text = batch_dict.get("text") + text_lens = batch_dict.get("text_lens") + speaker = batch_dict.get("speaker_id", None) + + with torch.no_grad(): + # [B, C, T_spec] + mels_pred, mels_pred_len, *_ = model.forward(text=text, input_lens=text_lens, speaker=speaker,) + + if self.log_spectrogram: + for i, audio_id in enumerate(audio_ids): + spec_i = mels_pred[i][:, : mels_pred_len[i]].cpu().numpy() + spec_artifact = ImageArtifact( + id=f"spec_{audio_id}", + data=spec_i, + filename=f"{audio_id}_spec.png", + x_axis="Audio Frames", + y_axis="Channels", + ) + image_artifacts.append(spec_artifact) + + if self.log_audio: + # [B, T_audio] + audio_pred, audio_pred_lens = self._generate_audio( + mels=mels_pred, mels_len=mels_pred_len, hop_length=model.preprocessor.hop_length + ) + for i, audio_id in enumerate(audio_ids): + audio_pred_i = audio_pred[i][: audio_pred_lens[i]].cpu().numpy() + audio_artifact = AudioArtifact( + id=f"audio_{audio_id}", + data=audio_pred_i, + filename=f"{audio_id}.wav", + sample_rate=self.vocoder.sample_rate, + ) + audio_artifacts.append(audio_artifact) + + return audio_artifacts, image_artifacts + + def _generate_gta_predictions(self, model: LightningModule, audio_ids: List[str], batch_dict: Dict): + audio_artifacts = [] + image_artifacts = [] + + audio = batch_dict.get("audio") + audio_lens = batch_dict.get("audio_lens") + text = batch_dict.get("text") + text_lens = batch_dict.get("text_lens") + attn_prior = batch_dict.get("align_prior_matrix", None) + pitch = batch_dict.get("pitch", None) + energy = batch_dict.get("energy", None) + speaker = batch_dict.get("speaker_id", None) + + mels, spec_len = model.preprocessor(input_signal=audio, length=audio_lens) + with torch.no_grad(): + mels_pred, mels_pred_len, _, _, _, attn, _, _, _, _, _, _ = model.forward( + text=text, + input_lens=text_lens, + pitch=pitch, + energy=energy, + speaker=speaker, + spec=mels, + mel_lens=spec_len, + attn_prior=attn_prior, + ) + + if self.log_alignment: + # [B, T_spec, T_text] + attn = attn.squeeze(1) + for i, audio_id in enumerate(audio_ids): + attn_i = attn[i][: mels_pred_len[i], : text_lens[i]].cpu().numpy() + alignment_artifact = ImageArtifact( + id=f"align_{audio_id}", + data=attn_i, + filename=f"{audio_id}_align.png", + x_axis="Audio Frames", + y_axis="Text Tokens", + ) + image_artifacts.append(alignment_artifact) + + if self.log_audio_gta: + # [B, T_audio] + audio_pred, audio_pred_lens = self._generate_audio( + 
mels=mels_pred, mels_len=mels_pred_len, hop_length=model.preprocessor.hop_length + ) + for i, audio_id in enumerate(audio_ids): + audio_pred_i = audio_pred[i][: audio_pred_lens[i]].cpu().numpy() + audio_artifact = AudioArtifact( + id=f"audio_gta_{audio_id}", + data=audio_pred_i, + filename=f"{audio_id}_gta.wav", + sample_rate=self.vocoder.sample_rate, + ) + audio_artifacts.append(audio_artifact) + + return audio_artifacts, image_artifacts + + def generate_artifacts( + self, model: LightningModule, batch_dict: Dict + ) -> Tuple[List[AudioArtifact], List[ImageArtifact]]: + + audio_artifacts = [] + image_artifacts = [] + audio_filepaths = batch_dict.get("audio_filepaths") + audio_ids = [create_id(p) for p in audio_filepaths] + + if self.log_audio or self.log_spectrogram: + audio_pred, spec_pred = self._generate_predictions(model=model, batch_dict=batch_dict, audio_ids=audio_ids) + audio_artifacts += audio_pred + image_artifacts += spec_pred + + if self.log_audio_gta or self.log_alignment: + audio_gta_pred, alignments = self._generate_gta_predictions( + model=model, batch_dict=batch_dict, audio_ids=audio_ids + ) + audio_artifacts += audio_gta_pred + image_artifacts += alignments + + return audio_artifacts, image_artifacts diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index 3109a9658ba3..3af727a848cf 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -484,6 +484,23 @@ def plot_spectrogram_to_numpy(spectrogram): return data +def create_plot(data, x_axis, y_axis, output_filepath=None): + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(data, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel(x_axis) + plt.ylabel(y_axis) + plt.tight_layout() + + if output_filepath: + plt.savefig(output_filepath, format="png") + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): fig, ax = plt.subplots(figsize=(12, 3)) ax.scatter(
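
For reference, the wiring that `FastPitchModel.configure_callbacks()` performs from the new `log_config` YAML section can also be done by hand with the classes this patch introduces. The sketch below is illustrative only and is not part of the patch: `log_data_loader`, the vocoder checkpoint path, and the output directory are placeholders that must point at a real `TextToSpeechDataset` loader and a real HiFi-GAN checkpoint; the constructor arguments themselves are the ones defined in `nemo/collections/tts/parts/utils/callbacks.py` above.

    from pathlib import Path

    from nemo.collections.tts.parts.utils.callbacks import (
        FastPitchArtifactGenerator,
        LogAudioParams,
        LoggingCallback,
    )

    # Placeholder: a DataLoader built from the TextToSpeechDataset configured under
    # model.log_config.dataset in the YAML above (batch keys include audio_filepaths,
    # audio, audio_lens, text, text_lens, ...).
    log_data_loader = ...

    generator = FastPitchArtifactGenerator(
        log_spectrogram=True,
        log_alignment=True,
        audio_params=LogAudioParams(
            vocoder_type="hifigan",                          # "hifigan" or "univnet"
            vocoder_name=None,                               # NGC model name, or None when using a checkpoint
            vocoder_checkpoint_path="path/to/hifigan.ckpt",  # exactly one of name / checkpoint_path must be set
            log_audio_gta=True,                              # also synthesize ground-truth-aligned audio
        ),
    )

    callback = LoggingCallback(
        generators=[generator],
        data_loader=log_data_loader,
        log_epochs=[10, 50],   # artifacts are logged when (1 + current_epoch) is in this list...
        epoch_frequency=100,   # ...or divisible by this frequency
        output_dir=Path("logs/audio_artifacts"),
        log_tensorboard=False,  # set True (and pass loggers=trainer.loggers) to also log to TensorBoard
        log_wandb=False,
    )

    # In the patch this callback is returned from FastPitchModel.configure_callbacks();
    # constructed manually, it would instead be passed as Trainer(callbacks=[callback], ...).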