From e3c9234f9148f445f8398fcc86c6db7c25c2b405 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Thu, 13 Feb 2025 17:04:47 -0800 Subject: [PATCH 001/113] add repeat index to help saving pred audio files for each repeat. (#50) moved t5tts script to magpietts Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- scripts/magpietts/infer_and_evaluate.py | 277 ++++++++++++++++++++++++ 1 file changed, 277 insertions(+) create mode 100755 scripts/magpietts/infer_and_evaluate.py diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py new file mode 100755 index 000000000000..3691e8cbcceb --- /dev/null +++ b/scripts/magpietts/infer_and_evaluate.py @@ -0,0 +1,277 @@ +from nemo.collections.tts.models import T5TTS_Model +from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset +from omegaconf.omegaconf import OmegaConf, open_dict +import os +import glob +import torch +import soundfile as sf +import evaluate_generated_audio +import evalset_config +import json +import argparse +import numpy as np +import scipy.stats as stats +import copy +import shutil +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest + +def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): + metrics = {} + for key in metric_keys: + measurements = [m[key] for m in metrics_list] + mean = np.mean(measurements) + std_err = stats.sem(measurements) + + confidence_interval = std_err * stats.t.ppf((1 + confidence) / 2, len(measurements) - 1) + print(f"{key}: {mean} +/- {confidence_interval}") + metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) + return metrics + +def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature, topk, codecmodel_path, use_cfg, cfg_scale, batch_size, num_repeats=1): + # import ipdb; ipdb.set_trace() + model_cfg = OmegaConf.load(hparams_file).cfg + + with open_dict(model_cfg): + model_cfg.codecmodel_path = codecmodel_path + if hasattr(model_cfg, 'text_tokenizer'): + # Backward compatibility for models trained with absolute paths in text_tokenizer + model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722" + model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0 + model_cfg.train_ds = None + model_cfg.validation_ds = None + + + model = T5TTS_Model(cfg=model_cfg) + model.use_kv_cache_for_inference = True + + # Load weights from checkpoint file + print("Loading weights from checkpoint") + ckpt = torch.load(checkpoint_file) + model.load_state_dict(ckpt['state_dict']) + print("Loaded weights.") + model.cuda() + model.eval() + # import ipdb; ipdb.set_trace() + + checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] + checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}".format(checkpoint_name, temperature, topk, use_cfg, cfg_scale) + dataset_meta_info = evalset_config.dataset_meta_info + for dataset in datasets: + metrics_n_repeated = [] + manifest_records = read_manifest(dataset_meta_info[dataset]['manifest_path']) + for repeat_idx in range(num_repeats): + eval_dir = os.path.join(out_dir, "{}_{}".format(checkpoint_name, dataset)) + audio_dir = os.path.join(eval_dir, "audio") + pred_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") + os.makedirs(pred_audio_dir, exist_ok=True) + language = dataset_meta_info[dataset].get('whisper_language', 'en') + 
dataset_meta_for_dl = copy.deepcopy(dataset_meta_info[dataset]) + for key in ["whisper_language", "load_cached_codes_if_available"]: + if key in dataset_meta_for_dl: + del dataset_meta_for_dl[key] + + dataset_meta = {dataset: dataset_meta_for_dl} + context_durration_min = model.cfg.get('context_duration_min', 5.0) + context_durration_max = model.cfg.get('context_duration_max', 5.0) + if context_durration_min < 5.0 and context_durration_max > 5.0: + context_durration_min = 5.0 + context_durration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. + test_dataset = T5TTSDataset( + dataset_meta=dataset_meta, + sample_rate=model_cfg.sample_rate, + min_duration=0.5, + max_duration=20, + codec_model_downsample_factor=model_cfg.codec_model_downsample_factor, + bos_id=model.bos_id, + eos_id=model.eos_id, + context_audio_bos_id=model.context_audio_bos_id, + context_audio_eos_id=model.context_audio_eos_id, + audio_bos_id=model.audio_bos_id, + audio_eos_id=model.audio_eos_id, + num_audio_codebooks=model_cfg.num_audio_codebooks, + prior_scaling_factor=None, + load_cached_codes_if_available=dataset_meta_info[dataset].get('load_cached_codes_if_available', True), + dataset_type='test', + tokenizer_config=None, + load_16khz_audio=model.model_type == 'single_encoder_sv_tts', + use_text_conditioning_tokenizer=model.use_text_conditioning_encoder, + pad_context_text_to_max_duration=model.pad_context_text_to_max_duration, + context_duration_min=context_durration_min, + context_duration_max=context_durration_max, + ) + assert len(test_dataset) == len(manifest_records), "Dataset length and manifest length should be the same. Dataset length: {}, Manifest length: {}".format(len(test_dataset), len(manifest_records)) + test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test') + + test_data_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=batch_size, + collate_fn=test_dataset.collate_fn, + num_workers=2, + shuffle=False + ) + + item_idx = 0 + for bidx, batch in enumerate(test_data_loader): + print("Processing batch {} out of {} of dataset {}".format(bidx, len(test_data_loader), dataset)) + batch_cuda ={} + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch_cuda[key] = batch[key].cuda() + else: + batch_cuda[key] = batch[key] + + import time + st = time.time() + predicted_audio, predicted_audio_lens, _, _ = model.infer_batch(batch_cuda, max_decoder_steps=440, temperature=temperature, topk=topk, use_cfg=use_cfg, cfg_scale=cfg_scale) + et = time.time() + print(f"Time taken for inference: {et-st}", predicted_audio.size()) + for idx in range(predicted_audio.size(0)): + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] + audio_path = os.path.join(pred_audio_dir, f"predicted_audio_{item_idx}.wav") + sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate) + context_audio_path = manifest_records[item_idx].get('context_audio_filepath', None) + target_audio_path = manifest_records[item_idx].get('audio_filepath', None) + if context_audio_path is not None: + context_audio_path = os.path.join(dataset_meta_info[dataset]['audio_dir'], context_audio_path) + if target_audio_path is not None: + target_audio_path = os.path.join(dataset_meta_info[dataset]['audio_dir'], target_audio_path) + if os.path.exists(context_audio_path): + shutil.copy(context_audio_path, 
os.path.join(audio_dir, f"context_audio_{item_idx}.wav")) + if os.path.exists(target_audio_path): + shutil.copy(target_audio_path, os.path.join(audio_dir, f"target_audio_{item_idx}.wav")) + item_idx += 1 + + metrics, filewise_metrics = evaluate_generated_audio.evaluate( + dataset_meta[dataset]['manifest_path'], + dataset_meta[dataset]['audio_dir'], + pred_audio_dir, + language=language, + ) + metrics_n_repeated.append(metrics) + with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: + json.dump(metrics, f, indent=4) + + with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: + # Indent for better readability + json.dump(filewise_metrics, f, indent=4) + + all_experiment_csv = os.path.join(out_dir, "all_experiment_metrics.csv") + if not os.path.exists(all_experiment_csv): + with open(all_experiment_csv, "w") as f: + f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative\n") + with open(all_experiment_csv, "a") as f: + f.write(f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']}\n") + print(f"Wrote metrics for {checkpoint_name} and {dataset} to {all_experiment_csv}") + + metric_keys = ['cer_filewise_avg', 'wer_filewise_avg', 'cer_cumulative', 'wer_cumulative', + 'ssim_pred_gt_avg', 'ssim_pred_context_avg', 'ssim_gt_context_avg', + 'ssim_pred_gt_avg_alternate', 'ssim_pred_context_avg_alternate', 'ssim_gt_context_avg_alternate', + 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative' + ] + metrics_mean_ci = compute_mean_and_confidence_interval(metrics_n_repeated, metric_keys) + all_experiment_csv_with_ci = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") + if not os.path.exists(all_experiment_csv_with_ci): + with open(all_experiment_csv_with_ci, "w") as f: + f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative\n") + with open(all_experiment_csv_with_ci, "a") as f: + f.write(f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']}\n") + print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") + + +def main(): + parser = argparse.ArgumentParser(description='Experiment Evaluation') + parser.add_argument('--hparams_files', type=str, 
default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") + parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt") + parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo") + parser.add_argument('--datasets', type=str, default="libri_seen_test,libri_unseen_test") + parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmount4/AllKernselSize3/NewTransformer") + parser.add_argument('--draco_exp_dir', type=str, default="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/experiments/NewT5TTS_FixedPosEmb/AllKernselSize3/NewTransformer") + parser.add_argument('--server_address', type=str, default="pneekhara@login-eos02.eos.clusters.nvidia.com") + parser.add_argument('--exp_names', type=str, default="multiencoder_small_sp_ks3_lnormapplied") + parser.add_argument('--local_ckpt_dir', type=str, default="/datap/misc/continuouscheckpoints_fixedposembrough") + parser.add_argument('--out_dir', type=str, default="/datap/misc/ContinuousEvalResults/NewTransformerKoelTTS") + parser.add_argument('--temperature', type=float, default=0.6) + parser.add_argument('--use_cfg', action='store_true') + parser.add_argument('--cfg_scale', type=float, default=1.0) + parser.add_argument('--topk', type=int, default=80) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--num_repeats', type=int, default=1) + args = parser.parse_args() + + if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null"): + hparam_files = args.hparams_files.split(",") + checkpoint_files = args.checkpoint_files.split(",") + print("Running inference for hparams files: ", hparam_files) + print("Running inference for checkpoint files: ", checkpoint_files) + assert len(hparam_files) == len(checkpoint_files), "Number of hparams files and checkpoint files should be the same." 
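+    # Note (editorial comment): each entry in --hparams_files is paired by index with the
+    # corresponding entry in --checkpoint_files via the zip() below.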
+ for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): + run_inference( + hparams_file, + checkpoint_file, + args.datasets.split(","), + args.out_dir, + args.temperature, + args.topk, + args.codecmodel_path, + args.use_cfg, + args.cfg_scale, + args.batch_size, + args.num_repeats + ) + return + else: + BASE_EXP_DIR = args.base_exp_dir + DRACO_EXP_DIR = args.draco_exp_dir + # Mount DRACO_EXP_DIR to BASE_EXP_DIR as follows: + # sshfs -o allow_other pneekhara@draco-oci-dc-02.draco-oci-iad.nvidia.com:/lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/experiments/NewT5AllFixedFresh /datap/misc/dracomount/ + if args.exp_names is None: + exp_names = os.listdir(BASE_EXP_DIR) + else: + exp_names = args.exp_names.split(",") + + for exp_name in exp_names: + exp_dir = os.path.join(BASE_EXP_DIR, exp_name) + # recurisvely look for hparams.yaml + try: + hparams_file = glob.glob(f"{exp_dir}/**/hparams.yaml", recursive=True)[0] + checkpoints_dir = glob.glob(f"{exp_dir}/**/checkpoints", recursive=True)[0] + last_checkpoint = (glob.glob(f"{checkpoints_dir}/*last.ckpt"))[0] + except: + print(f"Skipping experiment {exp_name} as hparams or last checkpoint not found.") + continue + last_checkpoint_path_draco = last_checkpoint.replace(BASE_EXP_DIR, DRACO_EXP_DIR) + epoch_num = last_checkpoint.split("epoch=")[1].split("-")[0] + + checkpoint_copy_path = os.path.join(args.local_ckpt_dir, f"{exp_name}_epoch_{epoch_num}.ckpt") + hparams_copy_path = os.path.join(args.local_ckpt_dir, f"{exp_name}_hparams.yaml") + + scp_command = f"scp {args.server_address}:{last_checkpoint_path_draco} {checkpoint_copy_path}" + print(f"Running command: {scp_command}") + os.system(scp_command) + print("Copied checkpoint.") + hparams_path_draco = hparams_file.replace(BASE_EXP_DIR, DRACO_EXP_DIR) + scp_command_hparams = f"scp {args.server_address}:{hparams_path_draco} {hparams_copy_path}" + print(f"Running command: {scp_command_hparams}") + os.system(scp_command_hparams) + print("Copied hparams file.") + # import ipdb; ipdb.set_trace() + print("Hparams file path: ", hparams_copy_path) + print("Checkpoint file path: ", checkpoint_copy_path) + run_inference( + hparams_copy_path, + checkpoint_copy_path, + args.datasets.split(","), + args.out_dir, + args.temperature, + args.topk, + args.codecmodel_path, + args.use_cfg, + args.cfg_scale, + args.batch_size + ) + + +if __name__ == '__main__': + main() \ No newline at end of file From d1747862d07bdfc45e732e485c25eb9763b83225 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 18 Feb 2025 15:28:17 -0800 Subject: [PATCH 002/113] fix: make confidence level configurable. 
(#51) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- scripts/magpietts/infer_and_evaluate.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 3691e8cbcceb..153ba40a0579 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -27,7 +27,7 @@ def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0 metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) return metrics -def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature, topk, codecmodel_path, use_cfg, cfg_scale, batch_size, num_repeats=1): +def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature, topk, codecmodel_path, use_cfg, cfg_scale, batch_size, num_repeats=1, confidence_level=0.95): # import ipdb; ipdb.set_trace() model_cfg = OmegaConf.load(hparams_file).cfg @@ -170,7 +170,7 @@ def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature, 'ssim_pred_gt_avg_alternate', 'ssim_pred_context_avg_alternate', 'ssim_gt_context_avg_alternate', 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative' ] - metrics_mean_ci = compute_mean_and_confidence_interval(metrics_n_repeated, metric_keys) + metrics_mean_ci = compute_mean_and_confidence_interval(metrics_n_repeated, metric_keys, confidence=confidence_level) all_experiment_csv_with_ci = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") if not os.path.exists(all_experiment_csv_with_ci): with open(all_experiment_csv_with_ci, "w") as f: @@ -198,6 +198,7 @@ def main(): parser.add_argument('--topk', type=int, default=80) parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--num_repeats', type=int, default=1) + parser.add_argument('--confidence_level', type=float, default=0.95) args = parser.parse_args() if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null"): @@ -218,7 +219,8 @@ def main(): args.use_cfg, args.cfg_scale, args.batch_size, - args.num_repeats + args.num_repeats, + args.confidence_level, ) return else: @@ -269,7 +271,9 @@ def main(): args.codecmodel_path, args.use_cfg, args.cfg_scale, - args.batch_size + args.batch_size, + num_repeats=args.num_repeats, + confidence_level=args.confidence_level, ) From 0949d57bb2be8a3230ed044ffe364045fcb6bca9 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Tue, 18 Feb 2025 19:13:01 -0800 Subject: [PATCH 003/113] Inference prior; updated by Jason * wip Signed-off-by: Paarth Neekhara * attn prior inference implementation Signed-off-by: Paarth Neekhara * more hacks Signed-off-by: Paarth Neekhara * minor tweaks Signed-off-by: Paarth Neekhara * clean ups and make text attention strictly monotonic at inference Signed-off-by: Paarth Neekhara * more updates Signed-off-by: Paarth Neekhara * minor tweaks Signed-off-by: Paarth Neekhara * compute head wise attention maps Signed-off-by: Paarth Neekhara * configurable ctc prior layers during training Signed-off-by: Paarth Neekhara * log only ctc prior layers on tensorboard Signed-off-by: Paarth Neekhara * add layerwise logging Signed-off-by: Paarth Neekhara * more configurable inference Signed-off-by: Paarth Neekhara * more conifigs Signed-off-by: Paarth Neekhara * updated end prediction logic as per discussion with roy Signed-off-by: Paarth Neekhara * DPO preference pair creations: add option to choose min length * Cleanup * handle 
cases where predicted codes are very small, havent tested but should work Signed-off-by: Paarth Neekhara * undo predicted len change since it is not needed Signed-off-by: Paarth Neekhara * clean up notebook Signed-off-by: Paarth Neekhara --------- Signed-off-by: Paarth Neekhara Co-authored-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 226 +++++++++++-- .../tts/modules/transformer_2501.py | 13 +- .../magpietts/dpo/create_preference_pairs.py | 12 +- scripts/magpietts/evalset_config.py | 230 +++++++++++++ scripts/magpietts/infer_and_evaluate.py | 121 +++++-- t5tts_inference.ipynb | 311 ++++++++++++++++++ 6 files changed, 851 insertions(+), 62 deletions(-) create mode 100644 scripts/magpietts/evalset_config.py create mode 100644 t5tts_inference.ipynb diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 9f0ae9c45d5f..4c9c5d84ac3f 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -14,6 +14,7 @@ import copy import json import os +import random import string from typing import List @@ -374,13 +375,18 @@ def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): return all_preds - def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80): + def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={}): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep all_preds = [] for idx in range(self.cfg.num_audio_codebooks): si = idx * self.cfg.num_audio_tokens_per_codebook ei = si + self.cfg.num_audio_tokens_per_codebook codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook) + for item_idx in unfinished_items: + codebook_logits[item_idx, self.audio_eos_id] = float('-inf') + for item_idx in finished_items: + codebook_logits[item_idx, :] = float('-inf') + codebook_logits[item_idx, self.audio_eos_id] = 0.0 codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -462,6 +468,12 @@ def scale_prior(self, prior, global_step): if global_step < prior_scaledown_start_step: return prior elif global_step >= prior_end_step: + if self.cfg.get('prior_always_applied', False): + # Added this so that model always knows how to work with and without the prior + if random.random() < 0.5: + return prior + else: + return None return None else: with torch.no_grad(): @@ -596,6 +608,16 @@ def prepare_context_tensors(self, batch): else: raise ValueError(f"Unsupported model type {self.model_type}") + if attn_prior is not None and self.cfg.get('ctc_prior_layer_ids', None) is not None: + ctc_prior_layer_ids = self.cfg.ctc_prior_layer_ids + # Convert prior to a list of tensors, one for each layer + # Set None for layers not in ctc_prior_layer_ids + if self.model_type == 'multi_encoder_context_tts': + text_attn_prior = [attn_prior[0] if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.t5_decoder.n_layers) ] + attn_prior = [text_attn_prior, attn_prior[1]] + else: + attn_prior = [attn_prior if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.t5_decoder.n_layers) ] + return { 'cond': cond, 'cond_mask': cond_mask, @@ -721,11 +743,8 @@ def process_batch(self, batch, mode="train"): alignment_loss = None if self.cfg.alignment_loss_scale > 0.0 and not disable_alignment_loss: text_lens = context_tensors['text_lens'] - 
cross_attention_scores = [ - attn['cross_attn_probabilities'][1] - for layer_idx, attn in enumerate(attn_info) - if layer_idx in self.transcript_decoder_layers - ] + ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) + cross_attention_scores = [attn['cross_attn_probabilities'][1] for layer_idx, attn in enumerate(attn_info) if layer_idx in ctc_prior_layer_ids] alignment_loss = self.compute_alignment_loss( cross_attention_scores, text_lens, audio_codes_lens_target, dec_context_size ) @@ -789,11 +808,8 @@ def validation_step(self, batch, batch_idx): and len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1 ): # cross_attn_probabilities only returned when not using flash attention - cross_attention_probs = [ - attn['cross_attn_probabilities'][0] - for layer_idx, attn in enumerate(attn_info) - if layer_idx in self.transcript_decoder_layers - ] + ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) + cross_attention_probs = [attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) if layer_idx in ctc_prior_layer_ids] self.log_attention_probs( cross_attention_probs, audio_codes_lens_target, @@ -801,6 +817,9 @@ def validation_step(self, batch, batch_idx): prefix="val_", dec_context_size=dec_context_size, ) + for layer_idx in self.transcript_decoder_layers: + cross_attention_probs = [ attn_info[layer_idx]['cross_attn_probabilities'][0] ] + self.log_attention_probs(cross_attention_probs, audio_codes_lens_target, text_lens, prefix=f"val_layer_{layer_idx}_", dec_context_size=dec_context_size) val_output = { 'val_loss': loss, @@ -811,7 +830,42 @@ def validation_step(self, batch, batch_idx): return val_output - def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, use_cfg=False, cfg_scale=1.0): + def get_cross_attention_scores(self, attn_probs, filter_layers=None): + """ + Returns the cross attention probabilities for the last audio timestep + """ + mean_cross_attn_scores = [] + all_heads_cross_attn_scores = [] + for lidx, layerwise_attn_prob in enumerate(attn_probs): + if (filter_layers is not None and lidx not in filter_layers) or (lidx not in self.transcript_decoder_layers): + continue + cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][0] # B, H, audio_timesteps, text_timesteps + mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1)) # B, audio_timesteps, text_timesteps + for head_idx in range(cross_attn_prob.size(1)): + all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :]) # B, text_timesteps + + mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1) # B, L, audio_timesteps, text_timesteps + mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1) # B, audio_timesteps, text_timesteps + last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps + return last_audio_timestep_scores, all_heads_cross_attn_scores + + def infer_batch( + self, + batch, + max_decoder_steps=500, + temperature=0.7, + topk=80, + use_cfg=False, + cfg_scale=1.0, + return_cross_attn_probs=False, + apply_attention_prior=False, + prior_epsilon=1e-5, + lookahead_window_size=10, + estimate_alignment_from_layers=None, + apply_prior_to_layers=None, + start_prior_after_n_audio_steps=10, + compute_all_heads_attn_maps=False, + ): with torch.no_grad(): self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference) @@ -837,6 +891,13 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, 
us ) ) + cross_attention_scores_all_timesteps = [] + all_heads_cross_attn_scores_all_timesteps = [] + _attn_prior = None + unfinished_texts = {} + finished_texts_counter = {} + attended_timestep_counter = [{} for _ in range(text.size(0))] + last_attended_timesteps = [[1 for _ in range(text.size(0))]] # Maintain a list of attended timesteps as we predict audio for each batch item for idx in range(max_decoder_steps): if idx % 20 == 0: print(f"Decoding timestep {idx}") @@ -850,7 +911,18 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us _audio_codes_embedded = audio_codes_embedded _audio_codes_mask = audio_codes_mask + if apply_prior_to_layers is not None: + attn_prior = [None for _ in range(self.cfg.t5_decoder.n_layers)] + for layer_idx in apply_prior_to_layers: + attn_prior[layer_idx] = _attn_prior + else: + attn_prior = _attn_prior + + if self.model_type == 'multi_encoder_context_tts': + attn_prior = [attn_prior, None] + if use_cfg: + # import ipdb; ipdb.set_trace() batch_size = audio_codes_embedded.size(0) # Combine conditional and unconditional inputs into one batch if isinstance(context_tensors['cond'], list): @@ -877,34 +949,104 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us dummy_addition_dec_mask ) - combined_logits, _ = self.forward( + combined_logits, attn_probs = self.forward( dec_input_embedded=cfg_audio_codes_embedded, dec_input_mask=cfg_audio_codes_mask, cond=cfg_cond, cond_mask=cfg_cond_mask, - attn_prior=None, - multi_encoder_mapping=context_tensors['multi_encoder_mapping'], + attn_prior=attn_prior, + multi_encoder_mapping=context_tensors['multi_encoder_mapping'] ) cond_logits = combined_logits[:batch_size] uncond_logits = combined_logits[batch_size:] all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits else: - all_code_logits, _ = self.forward( + batch_size = audio_codes_embedded.size(0) + all_code_logits, attn_probs = self.forward( dec_input_embedded=_audio_codes_embedded, dec_input_mask=_audio_codes_mask, cond=context_tensors['cond'], cond_mask=context_tensors['cond_mask'], - attn_prior=None, - multi_encoder_mapping=context_tensors['multi_encoder_mapping'], + attn_prior=attn_prior, + multi_encoder_mapping=context_tensors['multi_encoder_mapping'] ) - all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) - audio_codes_next = self.sample_codes_from_logits( - all_code_logits_t, temperature=temperature, topk=topk - ) # (B, num_codebooks) - all_codes_next_argmax = self.sample_codes_from_logits( - all_code_logits_t, temperature=0.01 - ) # (B, num_codebooks) + + if return_cross_attn_probs or apply_attention_prior: + cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores(attn_probs) # B, text_timesteps + alignment_attention_scores = cross_attention_scores + if estimate_alignment_from_layers is not None: + alignment_attention_scores, _ = self.get_cross_attention_scores(attn_probs, filter_layers=estimate_alignment_from_layers) # B, text_timesteps + text_time_step_attended = [] + for bidx in range(batch_size): + last_attended_timestep = last_attended_timesteps[-1][bidx] + if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8: + # This is probably an attention sink! 
Move to the next timestep + last_attended_timestep += 1 + window_size = lookahead_window_size + window_end = min(last_attended_timestep + window_size, context_tensors['text_lens'][bidx] - 3) # Ignore the last 3 timesteps + item_attention_scores = alignment_attention_scores[bidx,last_attended_timestep:window_end] + if item_attention_scores.size(0) == 0: + # This means the sentence has ended + attended_timestep = context_tensors['text_lens'][bidx] - 1 + else: + attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep + text_time_step_attended.append(attended_timestep) + attended_timestep_counter[bidx][attended_timestep] = attended_timestep_counter[bidx].get(attended_timestep, 0) + 1 + + last_attended_timesteps.append(text_time_step_attended) + cross_attention_scores_all_timesteps.append(cross_attention_scores) + all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores) + # if idx % 20 == 0: + # print("At timesteps", idx, text_time_step_attended, context_tensors['text_lens']) + + if apply_attention_prior and idx >= start_prior_after_n_audio_steps: + eps = prior_epsilon + # Attn prior for the next timestep + _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + eps + _attn_prior = _attn_prior.to(cross_attention_scores.device) + for bidx in range(cross_attention_scores.shape[0]): + if bidx < batch_size: + _text_len = context_tensors['text_lens'][bidx] + if context_tensors['text_lens'][bidx] <= 5: + # Very short sentences, No Prior + _attn_prior[bidx, 0, :] = 1.0 + else: + # _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-2)] = 0.1 # Slight exposure to history for better pronounciation. Not very important. + _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-1)] = 0.2 # Slight exposure to history for better pronounciation. Not very important. + _attn_prior[bidx, 0, text_time_step_attended[bidx]] = 0.8 # Slightly bias to continue moving forward. Not very important. + _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+1, _text_len - 1) ] = 1.0 + _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+2, _text_len - 1) ] = 0.8 + + # Penalize timesteps that have been attended to more than 10 times + for _timestep in attended_timestep_counter[bidx]: + if attended_timestep_counter[bidx][_timestep] >= 10: + # This means the timestep has been attended to more than 10 times (To avoid getting stuck) + _attn_prior[bidx, 0, _timestep] = eps + + unfinished_texts[bidx] = False + if text_time_step_attended[bidx] < context_tensors['text_lens'][bidx] - 3: + # This means the sentence has definitely not ended + if bidx not in end_indices: + unfinished_texts[bidx] = True + + if text_time_step_attended[bidx] >= context_tensors['text_lens'][bidx] - 5 or bidx in end_indices: + if bidx not in finished_texts_counter: + finished_texts_counter[bidx] = 0 + + for key in finished_texts_counter: + finished_texts_counter[key] += 1 + if finished_texts_counter[key] > 10: + # We should allow EOS to be predicted now. 
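+                                # (Editorial note: items flagged True in unfinished_texts have the EOS logit
+                                # masked out in sample_codes_from_logits; clearing the flag lifts that mask
+                                # so the item can terminate.)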
+ unfinished_texts[bidx] = False + + finished_items = {k: v for k, v in finished_texts_counter.items() if v >= 20} # Items that have been close to the end for atleast 20 timesteps + unifinished_items = {k: v for k, v in unfinished_texts.items() if v} + + all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) + audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) + all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) + for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: @@ -920,18 +1062,44 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us ) # (B, C, T') audio_codes_lens = audio_codes_lens + 1 audio_codes_mask = get_mask_from_lengths(audio_codes_lens) - if len(end_indices) == text.size(0): + if len(end_indices) == text.size(0) and len(all_predictions) >= 4: + # Codec must be of atleast 4 timesteps to be decoded properly print("All ends reached") break - predicted_codes = torch.stack(all_predictions, dim=-1) # (B, num_codebooks, T') - predicted_lens = [end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))] + + predicted_lens = [end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))] # Ensure that the codec is atleast of length 4 predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long() predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens) torch.cuda.empty_cache() - return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens + if return_cross_attn_probs: + cross_attention_scores_all_timesteps = torch.stack(cross_attention_scores_all_timesteps, dim=2) # B, text_timesteps, T' + + headwise_cross_attention_scores_all_timesteps = [] + for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): + head_cross_attention_all_timesteps = torch.stack([x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2) # B, text_timesteps, T' + headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps) + + cross_attention_maps = [] + headwise_cross_attention_maps = [] + for bidx in range(predicted_audio.size(0)): + item_cross_attention_scores = cross_attention_scores_all_timesteps[bidx,:context_tensors['text_lens'][bidx],:predicted_codes_lens[bidx]] + cross_attn_np = plot_alignment_to_numpy(item_cross_attention_scores.cpu().numpy()) + cross_attention_maps.append(cross_attn_np) + item_all_head_cross_attn_maps = [] + if compute_all_heads_attn_maps: + for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): + item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][bidx,:context_tensors['text_lens'][bidx],:predicted_codes_lens[bidx]] + headwise_cross_attn_np = plot_alignment_to_numpy(item_headwise_cross_attention_scores.cpu().numpy()) + item_all_head_cross_attn_maps.append(headwise_cross_attn_np) + headwise_cross_attention_maps.append(item_all_head_cross_attn_maps) + + return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, cross_attention_maps, headwise_cross_attention_maps + else: + # For backward compatibility + return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens def 
test_step(self, batch, batch_idx): with torch.no_grad(): diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index dc5debc04f39..28f498126d35 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -648,13 +648,22 @@ def _get_layer_inputs( if multi_encoder_mapping[idx] is None: return None, None, None else: + _attn_prior = attn_prior[multi_encoder_mapping[idx]] if attn_prior is not None else None + if isinstance(_attn_prior, list): + # @pneekhara: This means, we are passing layerwise attn_prior + _attn_prior = _attn_prior[idx] return ( cond[multi_encoder_mapping[idx]], cond_mask[multi_encoder_mapping[idx]] if cond_mask is not None else None, - attn_prior[multi_encoder_mapping[idx]] if attn_prior is not None else None, + _attn_prior, ) else: - return cond, cond_mask, attn_prior + if isinstance(attn_prior, list): + # @pneekhara: This means, we are passing layerwise attn_prior + _attn_prior = attn_prior[idx] + else: + _attn_prior = attn_prior + return cond, cond_mask, _attn_prior def forward( self, diff --git a/scripts/magpietts/dpo/create_preference_pairs.py b/scripts/magpietts/dpo/create_preference_pairs.py index 4d8ed40f3bb0..99d212c6c496 100644 --- a/scripts/magpietts/dpo/create_preference_pairs.py +++ b/scripts/magpietts/dpo/create_preference_pairs.py @@ -43,6 +43,7 @@ def main(): ) parser.add_argument("--group_size", type=int, default=4) parser.add_argument("--cer_threshold", type=float, default=0.02) + parser.add_argument("--min_length_threshold", type=float, default=1.5, help="Minimum length permitted. Set this shorter to allow very short sentences (which can be useful for DPO tuning.") parser.add_argument("--val_size", type=int, default=64) args = parser.parse_args() @@ -82,9 +83,7 @@ def main(): all_best_records, all_worst_records = create_chosen_rejected_records(records, group_size, num_chosen_per_group) print("Len all_best_records: ", len(all_best_records)) print("Len all_worst_records: ", len(all_worst_records)) - best_records, worst_records = filter_best_and_worst_records( - all_best_records, all_worst_records, args.cer_threshold - ) + best_records, worst_records = filter_best_and_worst_records(all_best_records, all_worst_records, args.cer_threshold, args.min_length_threshold) print("Len filtered best_records: ", len(best_records)) print("Len filtered worst_records: ", len(worst_records)) worst_records = normalize_rejected_rewards(worst_records) @@ -302,8 +301,7 @@ def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_gr print(f"Skipped {num_skipped} records due to invalid entries.") return best_records, worst_records - -def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.02): +def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.02, min_length_threshold=1.5): ridx = 0 filtered_best_records = [] filtered_worst_records = [] @@ -315,9 +313,7 @@ def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.0 best_record = best_records[ridx] if best_record['cer_gts'] < cer_threshold: worst_record = worst_records[ridx] - if (worst_record['duration'] > 19.0 or best_record['duration'] > 19.0) or ( - worst_record['duration'] < 1.5 or best_record['duration'] < 1.5 - ): + if (worst_record['duration'] > 19.0 or best_record['duration'] > 19.0) or (worst_record['duration'] < min_length_threshold or best_record['duration'] < min_length_threshold): skipped_records += 1 ridx += 1 
continue diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py new file mode 100644 index 000000000000..897add3b4a44 --- /dev/null +++ b/scripts/magpietts/evalset_config.py @@ -0,0 +1,230 @@ +dataset_meta_info = { + 'vctk': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json', + 'audio_dir' : '/datap/misc/Datasets/VCTK-Corpus', + 'feature_dir' : '/datap/misc/Datasets/VCTK-Corpus', + }, + 'riva_challenging': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', + 'audio_dir' : '/datap/misc/Datasets/riva', + 'feature_dir' : '/datap/misc/Datasets/riva', + }, + 'riva_challenging_shehzeen': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths_v2.json', + 'audio_dir' : '/Data/RivaData/riva', + 'feature_dir' : '/Data/RivaData/riva', + }, + 'riva_challenging_nozeros': { + # 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/riva_challenging_nozeros.json', + 'manifest_path': '/home/pneekhara/2023/SimpleT5NeMo/manifests/riva_challenging_filtered.json', + 'audio_dir' : '/datap/misc/Datasets/riva', + 'feature_dir' : '/datap/misc/Datasets/riva', + }, + 'libri_val': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', + 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + }, + 'libri_val_shehzeen': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/libri360_val.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + }, + 'libri_unseen_test': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + }, + 'libri_unseen_test_shehzeen_phoneme': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + 'tokenizer_names': ['english_phoneme'], + }, + 'libri_unseen_test_shehzeen_sep_char': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + 'tokenizer_names': ['english_chartokenizer'], + }, + 'libri_unseen_test_shehzeen_shared_char': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + 'tokenizer_names': ['chartokenizer'], + }, + 'libri_unseen_test_shehzeen_sp': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + 'tokenizer_names': ['multilingual_sentencepiece'], + }, + 'libri_unseen_test_shehzeen': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + }, + 'libri_seen_test_v2': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri_seen_evalset_from_testclean_v2.json', + 'audio_dir' : 
'/datap/misc/LibriTTSfromNemo/LibriTTS', + 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + }, + 'libri_seen_test_v2_shehzeen': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri_seen_evalset_from_testclean_v2.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + }, + 'libri_unseen_val': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/dev_clean_withContextAudioPaths_evalset.json', + 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + }, + 'spanish_cml_phoneme': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'tokenizer_names': ['spanish_phoneme'], + 'whisper_language': 'es' + }, + 'spanish_cml_sep_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'tokenizer_names': ['spanish_chartokenizer'], + 'whisper_language': 'es' + }, + 'spanish_cml_shared_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'tokenizer_names': ['chartokenizer'], + 'whisper_language': 'es' + }, + 'spanish_cml_sp': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', + 'tokenizer_names': ['multilingual_sentencepiece'], + 'whisper_language': 'es' + }, + 'german_cml_phoneme': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 'tokenizer_names': ['german_phoneme'], + 'whisper_language': 'de', + 'load_cached_codes_if_available': False + }, + 'german_cml_sep_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 'tokenizer_names': ['german_chartokenizer'], + 'whisper_language': 'de', + 'load_cached_codes_if_available': False + }, + 'german_cml_shared_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 'tokenizer_names': ['chartokenizer'], + 'whisper_language': 'de', + 'load_cached_codes_if_available': False + }, + 'german_cml_sp': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', + 
'tokenizer_names': ['multilingual_sentencepiece'], + 'whisper_language': 'de', + 'load_cached_codes_if_available': False + }, + 'french_cml_sep_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_french_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_french_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_french_v0.1', + 'tokenizer_names': ['french_chartokenizer'], + 'whisper_language': 'fr', + 'load_cached_codes_if_available': False + }, + 'french_cml_shared_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_french_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_french_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_french_v0.1', + 'tokenizer_names': ['chartokenizer'], + 'whisper_language': 'fr', + 'load_cached_codes_if_available': False + }, + 'french_cml_sp': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_french_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_french_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_french_v0.1', + 'tokenizer_names': ['multilingual_sentencepiece'], + 'whisper_language': 'fr', + 'load_cached_codes_if_available': False + }, + 'italian_cml_sep_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_italian_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', + 'tokenizer_names': ['italian_chartokenizer'], + 'whisper_language': 'it', + 'load_cached_codes_if_available': False + }, + 'italian_cml_shared_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_italian_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', + 'tokenizer_names': ['chartokenizer'], + 'whisper_language': 'it', + 'load_cached_codes_if_available': False + }, + 'italian_cml_sp': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_italian_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', + 'tokenizer_names': ['multilingual_sentencepiece'], + 'whisper_language': 'it', + 'load_cached_codes_if_available': False + }, + 'dutch_cml_sep_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_dutch_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', + 'tokenizer_names': ['dutch_chartokenizer'], + 'whisper_language': 'nl', + 'load_cached_codes_if_available': False + }, + 'dutch_cml_shared_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_dutch_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', + 'tokenizer_names': ['chartokenizer'], + 'whisper_language': 'nl', + 'load_cached_codes_if_available': False + }, + 'dutch_cml_sp': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_dutch_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', 
+ 'audio_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', + 'tokenizer_names': ['multilingual_sentencepiece'], + 'whisper_language': 'nl', + 'load_cached_codes_if_available': False + }, + 'portuguese_cml_sep_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_portuguese_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', + 'tokenizer_names': ['portuguese_chartokenizer'], + 'whisper_language': 'pt', + 'load_cached_codes_if_available': False + }, + 'polish_cml_sep_char': { + 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_polish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_polish_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_polish_v0.1', + 'tokenizer_names': ['polish_chartokenizer'], + 'whisper_language': 'pl', + 'load_cached_codes_if_available': False + }, +} \ No newline at end of file diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 153ba40a0579..315c6a2e832a 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -14,6 +14,7 @@ import copy import shutil from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from PIL import Image def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): metrics = {} @@ -27,7 +28,26 @@ def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0 metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) return metrics -def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature, topk, codecmodel_path, use_cfg, cfg_scale, batch_size, num_repeats=1, confidence_level=0.95): +def run_inference( + hparams_file, + checkpoint_file, + datasets, + out_dir, + temperature, + topk, + codecmodel_path, + use_cfg, + cfg_scale, + batch_size, + num_repeats=1, + apply_attention_prior=False, + attention_prior_epsilon=1e-3, + attention_prior_lookahead_window=10, + estimate_alignment_from_layers=None, + apply_prior_to_layers=None, + start_prior_after_n_audio_steps=10, + confidence_level=0.95 + ): # import ipdb; ipdb.set_trace() model_cfg = OmegaConf.load(hparams_file).cfg @@ -55,7 +75,19 @@ def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature, # import ipdb; ipdb.set_trace() checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] - checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}".format(checkpoint_name, temperature, topk, use_cfg, cfg_scale) + checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}".format( + checkpoint_name, + temperature, + topk, + use_cfg, + cfg_scale, + apply_attention_prior, + attention_prior_epsilon, + attention_prior_lookahead_window, + start_prior_after_n_audio_steps, + "".join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else "None", + "".join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else "None" + ) dataset_meta_info = evalset_config.dataset_meta_info for dataset in datasets: metrics_n_repeated = [] @@ -123,10 +155,28 @@ def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature, import time st = time.time() - predicted_audio, predicted_audio_lens, _, _ = 
model.infer_batch(batch_cuda, max_decoder_steps=440, temperature=temperature, topk=topk, use_cfg=use_cfg, cfg_scale=cfg_scale) + predicted_audio, predicted_audio_lens, _, _, cross_attention_maps, _ = model.infer_batch( + batch_cuda, + max_decoder_steps=440, + temperature=temperature, + topk=topk, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + return_cross_attn_probs=True, + apply_attention_prior=apply_attention_prior, + prior_epsilon=attention_prior_epsilon, + lookahead_window_size=attention_prior_lookahead_window, + estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=start_prior_after_n_audio_steps + ) + et = time.time() print(f"Time taken for inference: {et-st}", predicted_audio.size()) for idx in range(predicted_audio.size(0)): + cross_attn_map_image = Image.fromarray(cross_attention_maps[idx]) + cross_attn_map_image.save(os.path.join(audio_dir, f"cross_attn_map_{item_idx}.png")) + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] audio_path = os.path.join(pred_audio_dir, f"predicted_audio_{item_idx}.wav") @@ -195,12 +245,25 @@ def main(): parser.add_argument('--temperature', type=float, default=0.6) parser.add_argument('--use_cfg', action='store_true') parser.add_argument('--cfg_scale', type=float, default=1.0) + parser.add_argument('--apply_attention_prior', action='store_true') + parser.add_argument('--attention_prior_epsilon', type=float, default=1e-3) + parser.add_argument('--attention_prior_lookahead_window', type=int, default=10) + parser.add_argument('--estimate_alignment_from_layers', type=str, default=None) + parser.add_argument('--apply_prior_to_layers', type=str, default=None) + parser.add_argument('--start_prior_after_n_audio_steps', type=int, default=10) parser.add_argument('--topk', type=int, default=80) parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) args = parser.parse_args() + estimate_alignment_from_layers = None + if args.estimate_alignment_from_layers is not None: + estimate_alignment_from_layers = [int(l.strip()) for l in args.estimate_alignment_from_layers.split(",")] + apply_prior_to_layers = None + if args.apply_prior_to_layers is not None: + apply_prior_to_layers = [int(l.strip()) for l in args.apply_prior_to_layers.split(",")] + if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null"): hparam_files = args.hparams_files.split(",") checkpoint_files = args.checkpoint_files.split(",") @@ -209,18 +272,24 @@ def main(): assert len(hparam_files) == len(checkpoint_files), "Number of hparams files and checkpoint files should be the same." 
for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): run_inference( - hparams_file, - checkpoint_file, - args.datasets.split(","), - args.out_dir, - args.temperature, - args.topk, - args.codecmodel_path, - args.use_cfg, - args.cfg_scale, - args.batch_size, - args.num_repeats, - args.confidence_level, + hparams_file=hparams_file, + checkpoint_file=checkpoint_file, + datasets=args.datasets.split(","), + out_dir=args.out_dir, + temperature=args.temperature, + topk=args.topk, + codecmodel_path=args.codecmodel_path, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + batch_size=args.batch_size, + num_repeats=args.num_repeats, + apply_attention_prior=args.apply_attention_prior, + attention_prior_epsilon=args.attention_prior_epsilon, + attention_prior_lookahead_window=args.attention_prior_lookahead_window, + estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, + confidence_level=args.confidence_level, ) return else: @@ -264,15 +333,21 @@ def main(): run_inference( hparams_copy_path, checkpoint_copy_path, - args.datasets.split(","), - args.out_dir, - args.temperature, - args.topk, - args.codecmodel_path, - args.use_cfg, - args.cfg_scale, - args.batch_size, + datasets=args.datasets.split(","), + out_dir=args.out_dir, + temperature=args.temperature, + topk=args.topk, + codecmodel_path=args.codecmodel_path, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + batch_size=args.batch_size, num_repeats=args.num_repeats, + apply_attention_prior=args.apply_attention_prior, + attention_prior_epsilon=args.attention_prior_epsilon, + attention_prior_lookahead_window=args.attention_prior_lookahead_window, + estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, ) diff --git a/t5tts_inference.ipynb b/t5tts_inference.ipynb new file mode 100644 index 000000000000..f60833bd8186 --- /dev/null +++ b/t5tts_inference.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "466ccdc5", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from nemo.collections.tts.models import T5TTS_Model\n", + "from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample\n", + "from omegaconf.omegaconf import OmegaConf, open_dict\n", + "import torch\n", + "import os\n", + "import soundfile as sf\n", + "from IPython.display import display, Audio\n", + "import numpy as np\n", + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" + ] + }, + { + "cell_type": "markdown", + "id": "1f5798ac", + "metadata": {}, + "source": [ + "### Checkpoint Paths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04445f11", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "hparams_file = \"/datap/misc/ChallengingFinetuneLocalTraining/head1hparams.yaml\"\n", + "checkpoint_file = \"/datap/misc/ChallengingFinetuneLocalTraining/head1_epoch248.ckpt\"\n", + "codecmodel_path = \"/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo\"\n", + "\n", + "# Temp out dir for saving audios\n", + "out_dir = \"/datap/misc/t5tts_inference_notebook_samples\"\n", + "if not os.path.exists(out_dir):\n", + " os.makedirs(out_dir)" + ] + }, + { + "cell_type": 
"markdown", + "id": "86bf2a16", + "metadata": {}, + "source": [ + "### Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87bf66f9", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "model_cfg = OmegaConf.load(hparams_file).cfg\n", + "\n", + "with open_dict(model_cfg):\n", + " model_cfg.codecmodel_path = codecmodel_path\n", + " if hasattr(model_cfg, 'text_tokenizer'):\n", + " # Backward compatibility for models trained with absolute paths in text_tokenizer\n", + " model_cfg.text_tokenizer.g2p.phoneme_dict = \"scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt\"\n", + " model_cfg.text_tokenizer.g2p.heteronyms = \"scripts/tts_dataset_files/heteronyms-052722\"\n", + " model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0\n", + " model_cfg.train_ds = None\n", + " model_cfg.validation_ds = None\n", + "\n", + "\n", + "model = T5TTS_Model(cfg=model_cfg)\n", + "print(\"Loading weights from checkpoint\")\n", + "ckpt = torch.load(checkpoint_file)\n", + "model.load_state_dict(ckpt['state_dict'])\n", + "print(\"Loaded weights.\")\n", + "\n", + "model.use_kv_cache_for_inference = True\n", + "\n", + "model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "361b5711", + "metadata": {}, + "source": [ + "### Initialize Dataset class and helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840a7271", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "test_dataset = T5TTSDataset(\n", + " dataset_meta={},\n", + " sample_rate=model_cfg.sample_rate,\n", + " min_duration=0.5,\n", + " max_duration=20,\n", + " codec_model_downsample_factor=model_cfg.codec_model_downsample_factor,\n", + " bos_id=model.bos_id,\n", + " eos_id=model.eos_id,\n", + " context_audio_bos_id=model.context_audio_bos_id,\n", + " context_audio_eos_id=model.context_audio_eos_id,\n", + " audio_bos_id=model.audio_bos_id,\n", + " audio_eos_id=model.audio_eos_id,\n", + " num_audio_codebooks=model_cfg.num_audio_codebooks,\n", + " prior_scaling_factor=None,\n", + " load_cached_codes_if_available=True,\n", + " dataset_type='test',\n", + " tokenizer_config=None,\n", + " load_16khz_audio=model.model_type == 'single_encoder_sv_tts',\n", + " use_text_conditioning_tokenizer=model.use_text_conditioning_encoder,\n", + " pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,\n", + " context_duration_min=model.cfg.get('context_duration_min', 5.0),\n", + " context_duration_max=model.cfg.get('context_duration_max', 5.0),\n", + ")\n", + "test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test')\n", + "\n", + "\n", + "\n", + "def get_audio_duration(file_path):\n", + " with sf.SoundFile(file_path) as audio_file:\n", + " # Calculate the duration\n", + " duration = len(audio_file) / audio_file.samplerate\n", + " return duration\n", + "\n", + "def create_record(text, context_audio_filepath=None, context_text=None):\n", + " dummy_audio_fp = os.path.join(out_dir, \"dummy_audio.wav\")\n", + " dummy_audio = sf.write(dummy_audio_fp, np.zeros(22050 * 3), 22050) # 3 seconds of silence\n", + " record = {\n", + " 'audio_filepath' : dummy_audio_fp,\n", + " 'duration': 3.0,\n", + " 'text': text,\n", + " 'speaker': \"dummy\",\n", + " }\n", + " if context_text is not None:\n", + " assert context_audio_filepath is None\n", + " record['context_text'] = context_text\n", + " else:\n", + " assert context_audio_filepath is not None\n", + " 
record['context_audio_filepath'] = context_audio_filepath\n", + " record['context_audio_duration'] = get_audio_duration(context_audio_filepath)\n", + " \n", + " return record" + ] + }, + { + "cell_type": "markdown", + "id": "e9aa7a5a", + "metadata": {}, + "source": [ + "### Set transcript and context pairs to test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7374d3f", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# Change sample text and prompt audio/text here\n", + "audio_base_dir = \"/\"\n", + "test_entries = [\n", + " create_record(\n", + " text=\"This is an example of a regular text without any challlenging texts like repeated words or numbers just to do a sanity check.\",\n", + " context_text=\"Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_WIZWIKI |\",\n", + " #context_audio_filepath=\"/datap/misc/LibriTTSfromNemo/LibriTTS/test-clean/7729/102255/7729_102255_000012_000001.wav\", # Supply either context_audio_filepath or context_text, not both\n", + " ),\n", + "]\n", + "\n", + "data_samples = []\n", + "for entry in test_entries:\n", + " dataset_sample = DatasetSample(\n", + " dataset_name=\"sample\",\n", + " manifest_entry=entry,\n", + " audio_dir=audio_base_dir,\n", + " feature_dir=audio_base_dir,\n", + " text=entry['text'],\n", + " speaker=None,\n", + " speaker_index=0,\n", + " tokenizer_names=[\"english_phoneme\"], # Change this for multilingual: \"english_phoneme\", \"spanish_phoneme\", \"english_chartokenizer\", \"german_chartokenizer\".. \n", + " )\n", + " data_samples.append(dataset_sample)\n", + " \n", + "test_dataset.data_samples = data_samples\n", + "\n", + "test_data_loader = torch.utils.data.DataLoader(\n", + " test_dataset,\n", + " batch_size=1,\n", + " collate_fn=test_dataset.collate_fn,\n", + " num_workers=0,\n", + " shuffle=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "aab9866b", + "metadata": {}, + "source": [ + "### Generate With Prior" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "745b2ef7", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "item_idx = 0\n", + "for bidx, batch in enumerate(test_data_loader):\n", + " print(\"Processing batch {} out of {}\".format(bidx, len(test_data_loader)))\n", + " model.t5_decoder.reset_cache(use_cache=True)\n", + " batch_cuda ={}\n", + " for key in batch:\n", + " if isinstance(batch[key], torch.Tensor):\n", + " batch_cuda[key] = batch[key].cuda()\n", + " else:\n", + " batch_cuda[key] = batch[key]\n", + " import time\n", + " st = time.time()\n", + " \n", + " for _ in range(1):\n", + " for apply_prior in [True, False]:\n", + " predicted_audio, predicted_audio_lens, _, _, cross_attn_np, all_heads_attn_np = model.infer_batch(\n", + " batch_cuda, \n", + " max_decoder_steps=430, \n", + " temperature=0.6, \n", + " topk=80, \n", + " use_cfg=True,\n", + " cfg_scale=2.5,\n", + " prior_epsilon=0.1,\n", + " lookahead_window_size=5,\n", + " return_cross_attn_probs=True,\n", + " estimate_alignment_from_layers=[4,6,7],\n", + " apply_attention_prior=apply_prior,\n", + " apply_prior_to_layers=[3,4,5,6,7,8,9,10],\n", + " compute_all_heads_attn_maps=True,\n", + " start_prior_after_n_audio_steps=10\n", + " )\n", + " print(\"generation time\", time.time() - st)\n", + " for idx in range(predicted_audio.size(0)):\n", + " predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()\n", + " predicted_audio_np = 
predicted_audio_np[:predicted_audio_lens[idx]]\n", + " audio_path = os.path.join(out_dir, f\"predicted_audio_{item_idx}.wav\")\n", + " sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)\n", + " print(test_entries[bidx]['text'])\n", + " print(\"Prior Used?\", apply_prior)\n", + " display(Audio(audio_path))\n", + " item_idx += 1\n", + " plt.imshow(cross_attn_np[idx])\n", + " plt.show()\n", + "# for hidx, head_cross_attn in enumerate(all_heads_attn_np[idx]):\n", + "# layer_num = hidx // model.cfg.t5_decoder.xa_n_heads\n", + "# head_num = hidx % model.cfg.t5_decoder.xa_n_heads\n", + "# print(\"item, layer, head\", idx, layer_num, head_num)\n", + "# plt.imshow(all_heads_attn_np[idx][hidx])\n", + "# plt.show()\n", + "\n", + " print(\"------------------------------------\")\n", + " print(\"------------------------------------\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 6f475c7c1489cc906ee531d35c3d0a6c1a0b4f33 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 25 Feb 2025 11:54:21 -0800 Subject: [PATCH 004/113] Bugfix in DPO Pareto ranking (#53) When doing Pareto ranking make sure to only compare indices that correspond to metrics. --- scripts/magpietts/dpo/create_preference_pairs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/magpietts/dpo/create_preference_pairs.py b/scripts/magpietts/dpo/create_preference_pairs.py index 99d212c6c496..7519100c2641 100644 --- a/scripts/magpietts/dpo/create_preference_pairs.py +++ b/scripts/magpietts/dpo/create_preference_pairs.py @@ -166,6 +166,8 @@ def pareto_rank(items): # A helper function to check if item A is dominated by item B # A: (cerA, ssimA), B: (cerB, ssimB) def is_dominated(A, B): + assert len(A) == 2 + assert len(B) == 2 return (B[0] <= A[0]) and (B[1] >= A[1]) and (B != A) # Equivalently, check at least one strict inequality: # (B[0] < A[0]) or (B[1] > A[1]) @@ -185,7 +187,7 @@ def is_dominated(A, B): dominated = False for j in range(len(remaining)): if i != j: - if is_dominated(remaining[i], remaining[j]): + if is_dominated(remaining[i][:2], remaining[j][:2]): dominated = True break if not dominated: From 59bb9cf13c9a86bebe9a2d48b9eaac00a0d9c645 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Thu, 6 Mar 2025 11:45:03 -0800 Subject: [PATCH 005/113] Local Transformer and Binarized attention prior from alignment module (#52) Updated by Jason * local transformer training tested, prediction not tested Signed-off-by: Paarth Neekhara * local transformer updates Signed-off-by: Paarth Neekhara * local transformer inference working Signed-off-by: Paarth Neekhara * aligner module Signed-off-by: Paarth Neekhara * aligner module updates Signed-off-by: Paarth Neekhara * wip Signed-off-by: Paarth Neekhara * wip Signed-off-by: Paarth Neekhara * change aligner text input to encoder output Signed-off-by: Paarth Neekhara * obtain hard alignment from t5tts decoder Signed-off-by: Paarth Neekhara * log hard attention training Signed-off-by: Paarth Neekhara * binarization method, obtain_prior_from_cross_attn fix Signed-off-by: Paarth Neekhara * added configs for local transformer and alignment encoder Signed-off-by: 
Paarth Neekhara * added prior window decay factors Signed-off-by: Paarth Neekhara * more configs.. Signed-off-by: Paarth Neekhara * config was missing Signed-off-by: Paarth Neekhara * slight modification in alignment encoder computation, pass target audio embeddings (removing bos) Signed-off-by: Paarth Neekhara * some comments Signed-off-by: Paarth Neekhara * prior prob configurable Signed-off-by: Paarth Neekhara * update yamls Signed-off-by: Paarth Neekhara * refactor inference prior code Signed-off-by: Paarth Neekhara * set prior epsilon to 0 to avoid any attention scores on unintended parts Signed-off-by: Paarth Neekhara * make prior epsilon configurable in training Signed-off-by: Paarth Neekhara * added rtf metrics and notebook, infer and evaluate changes Signed-off-by: Paarth Neekhara * turn off alignment encoder training after 50k steps Signed-off-by: Paarth Neekhara --------- Signed-off-by: Paarth Neekhara --- examples/tts/conf/magpietts/magpietts_en.yaml | 28 +- .../magpietts/magpietts_inference_en.yaml | 25 + .../magpietts_inference_multilingual_v1.yaml | 25 + .../tts/conf/magpietts/magpietts_lhotse.yaml | 270 +++++++++ .../magpietts/magpietts_multilingual_v1.yaml | 26 + nemo/collections/tts/models/magpietts.py | 531 ++++++++++++++---- .../tts/modules/transformer_2501.py | 12 +- scripts/magpietts/evalset_config.py | 12 + scripts/magpietts/infer_and_evaluate.py | 44 +- t5tts_inference.ipynb | 116 ++-- 10 files changed, 924 insertions(+), 165 deletions(-) create mode 100644 examples/tts/conf/magpietts/magpietts_lhotse.yaml diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index e71d9ac5d261..c46ab208dfe9 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -29,7 +29,8 @@ model: load_cached_codes_if_available: true prior_scaling_factor: 0.5 prior_end_step: 12000 - prior_scaledown_start_step: 8000 + prior_scaledown_start_step: 8000 # Prior will always be on before this step. + indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. alignment_loss_scale: 0.0 embedding_dim: 768 codecmodel_path: ??? @@ -38,6 +39,30 @@ model: sample_rate: ${sample_rate} + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. 
Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + use_local_transformer: false + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer @@ -125,6 +150,7 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true + prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index eec62db5547d..b1b042f8663d 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -46,6 +46,30 @@ model: sample_rate: ${sample_rate} + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. 
Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + use_local_transformer: false + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer @@ -118,6 +142,7 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true + prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index 9831131092ed..74ca87fea48c 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -46,6 +46,30 @@ model: sample_rate: ${sample_rate} + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + use_local_transformer: false + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer @@ -166,6 +190,7 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true + prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/examples/tts/conf/magpietts/magpietts_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_lhotse.yaml new file mode 100644 index 000000000000..26d58cac63e3 --- /dev/null +++ b/examples/tts/conf/magpietts/magpietts_lhotse.yaml @@ -0,0 +1,270 @@ +name: T5TTS + +max_steps: ??? +limit_val_batches: ??? +# Adjust batch size based on GPU memory +batch_size: 16 +micro_batch_size: 16 +batch_duration: ??? +eval_batch_size: ??? + +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +train_ds_meta: ??? +val_ds_meta: ??? 
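As an aside on the prior_future_context / prior_past_context and decay options these configs keep repeating: the sketch below is a minimal, self-contained illustration (not the model code) of how a one-hot alignment from the alignment encoder gets widened and decayed before being used as a prior. The real logic lives in get_binarized_prior_matrix in magpietts.py later in this patch; the function name and tensor sizes here are only for illustration.

import torch

def widen_binarized_prior(attn_hard, future_ctx=2, past_ctx=2, future_decay=0.8, past_decay=0.5, epsilon=0.0):
    # attn_hard: (B, audio_T, text_T) one-hot alignment, float
    wider = attn_hard + epsilon
    for k in range(future_ctx):
        # allow each audio frame to also look at the next (k+1) text tokens, decayed
        wider[:, :, k + 1:] += (future_decay ** (k + 1)) * attn_hard[:, :, :-(k + 1)]
    for k in range(past_ctx):
        # ... and at the previous (k+1) text tokens, decayed
        wider[:, :, :-(k + 1)] += (past_decay ** (k + 1)) * attn_hard[:, :, k + 1:]
    return wider.clamp(0.0, 1.0)

# Example: 4 audio frames aligned one-to-one to the first 4 of 6 text tokens
hard = torch.nn.functional.one_hot(torch.tensor([[0, 1, 2, 3]]), num_classes=6).float()
print(widen_binarized_prior(hard))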
+ +# Modify these values based on your sample rate +sample_rate: 22050 + +phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" +heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" +model: + use_lhotse: true + model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer + use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. + transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. + context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts + context_duration_min: 3.0 + context_duration_max: 8.0 + speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise + num_audio_codebooks: 8 + num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids + codec_model_downsample_factor: 1024 + load_cached_codes_if_available: true + prior_scaling_factor: 0.5 + prior_end_step: 12000 + prior_scaledown_start_step: 8000 + indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. + alignment_loss_scale: 0.0 + embedding_dim: 768 + codecmodel_path: ??? + max_steps: ${max_steps} + + sample_rate: ${sample_rate} + + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. 
Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + use_local_transformer: false + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + + text_tokenizers: + english_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: ${phoneme_dict_path} + heteronyms: ${heteronyms_path} + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + + train_ds: + use_lhotse: ${model.use_lhotse} + dataset: + input_cfg: + - type: lhotse_shar + shar_path: /cluster_data/TTS/tts_lhotse_datasets/hifitts_v0/ + weight: 1.0 + tags: + lang: en + s2s: True + tokenizer_names: ["english_phoneme"] + + - type: lhotse_shar + shar_path: /cluster_data/TTS/tts_lhotse_datasets/libri100/ + weight: 1.0 + tags: + lang: en + s2s: True + tokenizer_names: ["english_phoneme"] + + - type: lhotse_shar + shar_path: /cluster_data/TTS/tts_lhotse_datasets/rivaLindyRodney/ + weight: 1.0 + tags: + lang: en + s2s: True + tokenizer_names: ["english_phoneme"] + + - type: lhotse_shar + shar_path: /cluster_data/TTS/tts_lhotse_datasets/libri360/ + weight: 1.0 + tags: + lang: en + s2s: True + tokenizer_names: ["english_phoneme"] + + global_batch_size: ${batch_size} + micro_batch_size: ${micro_batch_size} + batch_size: null + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 512 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. 
+ concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + # ASR configs + sample_rate: ${model.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + # tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + use_bucketing: true + use_lhotse: ${model.use_lhotse} + text_field : "text" + seed: 'trng' + batch_duration : ${batch_duration} # 0 + quadratic_duration : 20 + num_buckets : 31 + buffer_size : 10000 + shuffle_buffer_size : 10000 + num_cuts_for_bins_estimate: 10000 + duration_bins: [3.155,3.76,4.27,4.74,5.1935,5.64,6.096,6.588,7.14,7.81,8.28,8.664,9.072,9.57,10.14,10.7335,11.3735,12.09,12.78,13.41,14.01,14.62,15.253375,15.96875,16.71,17.45,18.1335,18.7735,19.4,20.0] + # bucket_duration_bins: [3.155,3.76,4.27,4.74,5.1935,5.64,6.096,6.588,7.14,7.81,8.28,8.664,9.072,9.57,10.14,10.7335,11.3735,12.09,12.78,13.41,14.01,14.62,15.253375,15.96875,16.71,17.45,18.1335,18.7735,19.4,20.0] + + validation_ds: + use_lhotse: ${model.use_lhotse} + dataset: + input_cfg: + - type: lhotse_shar + shar_path: /cluster_data/TTS/tts_lhotse_datasets/LibriTTS_dev_clean/ + weight: 1.0 + tags: + lang: en + s2s: True + tokenizer_names: ["english_phoneme"] + + global_batch_size: ${batch_size} + micro_batch_size: ${micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + drop_last: False + use_bucketing: false + is_tarred: false + batch_size: ${eval_batch_size} + + t5_encoder: + n_layers: 6 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: False + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise + n_layers: 3 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: false + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + t5_decoder: + n_layers: 12 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: true + xa_d_memory: 768 + xa_n_heads: 12 + is_causal: true + apply_norm_to_cond: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + prior_eps: 1e-8 + + optim: + _target_: torch.optim.Adam + lr: 2e-4 + betas: [0.8, 0.99] + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: 32 + max_epochs: -1 + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 10 + val_check_interval: 500 + limit_train_batches: ${trainer.val_check_interval} + # check_val_every_n_epoch: 10 + benchmark: false + max_steps: ${max_steps} + limit_val_batches: ${limit_val_batches} + use_distributed_sampler: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + 
resume_if_exists: true + resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index 82d9ec6f3890..9f3a4d14f03d 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -30,6 +30,7 @@ model: prior_scaling_factor: 0.5 prior_end_step: 12000 prior_scaledown_start_step: 8000 + indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. alignment_loss_scale: 0.0 embedding_dim: 768 codecmodel_path: ??? @@ -38,6 +39,30 @@ model: sample_rate: ${sample_rate} + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + use_local_transformer: false + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer @@ -173,6 +198,7 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true + prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 4c9c5d84ac3f..ebcf98d984e8 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -16,6 +16,7 @@ import os import random import string +import time from typing import List import librosa @@ -30,13 +31,16 @@ from torch.utils.data import get_worker_info from transformers import AutoTokenizer, T5Tokenizer + import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer +from nemo.collections.tts.data.text_to_speech_dataset_lhotse import build_lhotse_dataloader, T5TTSLhotseDataset from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 -from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths, plot_alignment_to_numpy +from nemo.collections.tts.modules.aligner import AlignmentEncoder +from nemo.collections.tts.parts.utils.helpers import binarize_attention_parallel, 
get_mask_from_lengths, plot_alignment_to_numpy from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo @@ -163,8 +167,36 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) - self.final_proj = nn.Linear(cfg.decoder.d_model, cfg.num_audio_codebooks * cfg.num_audio_tokens_per_codebook) + if cfg.get('use_local_transformer', False): + local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256) + if local_transformer_hidden_dim != cfg.t5_decoder.d_model: + self.local_transformer_in_projection = nn.Linear(cfg.t5_decoder.d_model, local_transformer_hidden_dim) + else: + self.local_transformer_in_projection = nn.Identity() + self.local_transformer = t5tts_transformer.Transformer( + n_layers=self.cfg.get('local_transformer_n_layers', 2), + d_model=local_transformer_hidden_dim, + d_ffn=local_transformer_hidden_dim*4, + sa_n_heads=self.cfg.get('local_transformer_n_heads', 1), + kernel_size=1, + is_causal=True, + max_length_causal_mask=cfg.num_audio_codebooks+2, + use_learnable_pos_emb=True, + ) + local_transformer_out_projections = [] + for _ in range(cfg.num_audio_codebooks): + # Have a separate projection layer for each codebook, to distinguish between them + local_transformer_out_projections.append(nn.Linear(local_transformer_hidden_dim, cfg.num_audio_tokens_per_codebook)) + self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections) + + if cfg.get('use_alignment_encoder', False): + self.alignment_encoder = AlignmentEncoder( + n_mel_channels=cfg.embedding_dim, + n_text_channels=cfg.embedding_dim, + dist_type="cosine", + temperature=15.0, + ) codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) # del codec discriminator to free memory @@ -210,8 +242,11 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0) + alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0) if alignment_loss_scale > 0.0: self.alignment_loss = ForwardSumLoss(loss_scale=alignment_loss_scale) + if alignment_encoder_loss_scale > 0.0: + self.alignment_encoder_loss = ForwardSumLoss(loss_scale=alignment_encoder_loss_scale) def freeze_model(self, model): for param in model.parameters(): @@ -319,6 +354,39 @@ def get_speaker_embeddings(self, audio_16khz, audio_len_16khz): ) return speaker_embeddings + def compute_local_transformer_logits(self, dec_out, audio_codes_target): + """ + Loss from the autoregrssive codebook predictor (used per frame) + """ + # dec_out: (B, T', E) + # audio_codes: (B, C, T') + dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E) + local_transformer_input = [dec_out_all] + for codebook_num in range(audio_codes_target.size(1)): + codes = audio_codes_target[:, codebook_num] # (B, T') + codes = codes.reshape(-1) # (B*T',) + codebook_embedding = self.audio_embeddings[codebook_num](codes) # (B*T', E) + local_transformer_input.append(codebook_embedding) + + local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E) + local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128) + _mask = torch.ones( local_transformer_input.size(0), 
local_transformer_input.size(1), device=local_transformer_input.device) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E) + local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E) + all_code_logits = [] + for codebook_num in range(audio_codes_target.size(1)): + # Using a separate projection layer for each codebook (to distinguish between them) + # Checked the time - this loop is not taking much time (compared to the local transformer forward pass) + codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, codebook_num, :]) # (B*T', num_audio_tokens_per_codebook) + all_code_logits.append(codebook_logits) + all_code_logits = torch.cat(all_code_logits, dim=1) # (B*T', num_codebooks * num_audio_tokens_per_codebook) + + all_code_logits = all_code_logits.view( + audio_codes_target.size(0), audio_codes_target.size(2), -1 + ) # (B, T', C * num_audio_tokens_per_codebook) + + return all_code_logits + def compute_loss(self, logits, audio_codes, audio_codes_lens): # logits: (B, T', num_codebooks * num_tokens_per_codebook) # audio_codes: (B, C, T') @@ -354,7 +422,7 @@ def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prio ) attn_probabilities = decoder_out['attn_probabilities'] all_code_logits = self.final_proj(decoder_out['output']) # (B, T', num_codebooks * num_tokens_per_codebook) - return all_code_logits, attn_probabilities + return all_code_logits, attn_probabilities, decoder_out['output'] def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): # all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook) @@ -375,6 +443,50 @@ def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): return all_preds + def sample_codes_from_local_transformer(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0): + # dec_output: (B, E) + # import ipdb; ipdb.set_trace() + self.local_transformer.reset_cache(use_cache=True) + dec_output = dec_output.unsqueeze(1) # (B, 1, E) + local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) + all_preds = [] + for codebook_num in range(self.cfg.num_audio_codebooks): + _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128) + codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, -1, :]) # (B, num_audio_tokens_per_codebook) + if use_cfg: + actual_batch_size = codebook_logits.size(0) // 2 + conditional_logits = codebook_logits[:actual_batch_size] + unconditional_logits = codebook_logits[actual_batch_size:] + cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits + codebook_logits[:actual_batch_size] = cfg_logits + + for item_idx in unfinished_items: + codebook_logits[item_idx, self.audio_eos_id] = float('-inf') + for item_idx in finished_items: + codebook_logits[item_idx, :] = float('-inf') + codebook_logits[item_idx, self.audio_eos_id] = 0.0 + + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) + indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(-1) # (B, num_tokens_per_codebook) + codebook_logits_rescored = codebook_logits.clone() + codebook_logits_rescored[indices_to_remove] = float('-inf') + codebook_probs = 
torch.softmax(codebook_logits / temperature, dim=-1) # (B, num_tokens_per_codebook) + codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) + if use_cfg: + codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size] + all_preds.append(codebook_preds) + next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze(1) # (B, 1, 128) + next_local_transformer_input = self.local_transformer_in_projection(next_local_transformer_input) # (B, 1, 128) + local_transformer_input = torch.cat([local_transformer_input, next_local_transformer_input], dim=1) # (B, T+1, 128) + + all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks) + if use_cfg: + all_preds = all_preds[:actual_batch_size] + + return all_preds + + def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={}): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep all_preds = [] @@ -468,13 +580,13 @@ def scale_prior(self, prior, global_step): if global_step < prior_scaledown_start_step: return prior elif global_step >= prior_end_step: - if self.cfg.get('prior_always_applied', False): - # Added this so that model always knows how to work with and without the prior - if random.random() < 0.5: - return prior - else: - return None - return None + indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0) + if random.random() < indefinite_prior_prob: + print("Using Prior") + return prior + else: + print("Not using Prior") + return None else: with torch.no_grad(): # Interpolate between all ones and the prior @@ -619,19 +731,81 @@ def prepare_context_tensors(self, batch): attn_prior = [attn_prior if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.t5_decoder.n_layers) ] return { + 'beta_binomial_attn_prior': batch.get('align_prior_matrix', None), + 'text_encoder_out': text_encoder_out, 'cond': cond, 'cond_mask': cond_mask, 'attn_prior': attn_prior, + 'prior_used': _attn_prior is not None, 'multi_encoder_mapping': multi_encoder_mapping, 'additional_decoder_input': additional_decoder_input, 'addtional_decoder_mask': addtional_decoder_mask, 'dec_context_size': dec_context_size, 'text': text, + 'text_embedded': text_embedded, + 'text_mask': text_mask, 'text_lens': text_lens, 'context_audio_codes': context_audio_codes, 'context_audio_codes_lens': context_audio_codes_lens, } + def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_hard): + # aligner_attn_hard B, audio_timesteps, text_timesteps + if self.model_type == 'multi_encoder_context_tts': + text_attn_prior = attn_prior[0] + else: + text_attn_prior = attn_prior + + assert text_attn_prior is not None, "Prior is None" + + if isinstance(text_attn_prior, list): + # Layer wise prior + prior_updated = False + for idx, prior in enumerate(text_attn_prior): + if prior is not None: + text_attn_prior[idx][:,-aligner_attn_hard.size(1):,:] = aligner_attn_hard + prior_updated = True + assert prior_updated, "Did not find any prior to update" + else: + # Same prior for all layers + text_attn_prior[:,-aligner_attn_hard.size(1):,:] = aligner_attn_hard + + if self.model_type == 'multi_encoder_context_tts': + attn_prior[0] = text_attn_prior + else: + attn_prior = text_attn_prior + + return attn_prior + + def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens): + # aligner_attn_soft B, 1, audio_timesteps, text_timesteps + if self.cfg.get('binarize_attn_method', 
'argmax') == 'nemo_binarize': + binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2) + aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave(binarize_repeat_audio_factor, dim=2) # B, 1, 2*audio_timesteps, text_timesteps + aligner_attn_hard = binarize_attention_parallel(aligner_attn_soft_repeated, text_lens, audio_lens*binarize_repeat_audio_factor).squeeze(1) # B, 2*audio_timesteps, text_timesteps + aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps + else: + print("Binaraizing attention using argmax") + aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1) + aligner_attn_hard = torch.nn.functional.one_hot(aligner_attn_hard, num_classes=aligner_attn_soft.size(-1)).float() + + prior_future_decay = self.cfg.get('prior_future_decay', 1.0) + prior_past_decay = self.cfg.get('prior_past_decay', 1.0) + binarized_prior_epsilon = self.cfg.get('binarized_prior_epsilon', 0.0) + aligner_attn_hard_wider = aligner_attn_hard + binarized_prior_epsilon + + for future_timestep in range(self.cfg.get('prior_future_context', 1)): + decay_factor = prior_future_decay ** (future_timestep + 1) + aligner_attn_hard_wider[:,:,future_timestep+1:] += decay_factor * aligner_attn_hard[:,:,:-(future_timestep+1)] + + for past_timestep in range(self.cfg.get('prior_past_context', 1)): + decay_factor = prior_past_decay ** (past_timestep + 1) + aligner_attn_hard_wider[:,:,:-past_timestep-1] += decay_factor * aligner_attn_hard[:,:,past_timestep+1:] + + aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0) + + return aligner_attn_hard_wider + def prepare_dummy_cond_for_cfg(self, cond, cond_mask, additional_decoder_input, additional_dec_mask): dummy_additional_decoder_input = None dummy_additional_dec_mask = None @@ -674,6 +848,8 @@ def process_batch(self, batch, mode="train"): audio_codes_input = audio_codes[:, :, :-1] # B, C, T' audio_codes_target = audio_codes[:, :, 1:] audio_codes_lens_input = audio_codes_lens_target = audio_codes_lens - 1 + audio_codes_embedded_all = self.embed_audio_tokens(audio_codes) # (B, T, E) # Computing this to be use in the alignment encoder + audio_codes_embedded = audio_codes_embedded_all[:, :-1, :] # (B, T', E) Input to the decoder audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input) use_cfg = ( @@ -718,8 +894,8 @@ def process_batch(self, batch, mode="train"): ) # timestep_mask is True for timesteps to be kept audio_codes_input = audio_codes_input * dec_dropout_mask + random_audio_tokens * (~dec_dropout_mask) + audio_codes_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T', E) - audio_codes_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T', E) if context_tensors['additional_decoder_input'] is not None: dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1) dec_input_mask = torch.cat([additional_decoder_mask, audio_codes_mask], dim=1) @@ -727,7 +903,44 @@ def process_batch(self, batch, mode="train"): dec_input_embedded = audio_codes_embedded dec_input_mask = audio_codes_mask - logits, attn_info = self.forward( + aligner_encoder_loss = None + aligner_attn_soft = None + aligner_attn_hard = None + if self.cfg.get('use_alignment_encoder', False) and not disable_alignment_loss: + aligner_prior = None + if self.cfg.get('use_prior_for_aligner', False): + aligner_prior = context_tensors['beta_binomial_attn_prior'] + # Passing target audio embeddings to the alignment encoder + if self.global_step < 
self.cfg.get('aligner_encoder_train_steps', float('inf')): + aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( + queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' + keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T + mask=~context_tensors['text_mask'].unsqueeze(-1), + attn_prior=aligner_prior + ) + + aligner_encoder_loss = self.alignment_encoder_loss( + attn_logprob=aligner_attn_logprobs, in_lens=context_tensors['text_lens'], out_lens=audio_codes_lens_input + ) + else: + with torch.no_grad(): + # Just get the attention matrix without computing the loss or gradients + aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( + queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' + keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T + mask=~context_tensors['text_mask'].unsqueeze(-1), + attn_prior=aligner_prior + ) + + with torch.no_grad(): + aligner_attn_hard = self.get_binarized_prior_matrix( + aligner_attn_soft, audio_codes_lens_input, context_tensors['text_lens'] + ) + if (self.global_step > self.cfg.get('binarize_prior_after_step', 0)) and context_tensors['prior_used']: + print("Updating Prior") + attn_prior = self.replace_beta_binomial_prior_with_binarized(attn_prior, aligner_attn_hard) + + logits, attn_info, dec_out = self.forward( dec_input_embedded=dec_input_embedded, dec_input_mask=dec_input_mask, cond=cond, @@ -736,10 +949,12 @@ def process_batch(self, batch, mode="train"): multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) # logits: (B, T', num_codebooks * num_tokens_per_codebook) + # dec_out: (B, T', E) dec_context_size = context_tensors['dec_context_size'] logits = logits[:, dec_context_size:, :] # Remove the context audio embeddings from the logits codebook_loss, loss_mask = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target) + codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0) alignment_loss = None if self.cfg.alignment_loss_scale > 0.0 and not disable_alignment_loss: text_lens = context_tensors['text_lens'] @@ -748,17 +963,31 @@ def process_batch(self, batch, mode="train"): alignment_loss = self.compute_alignment_loss( cross_attention_scores, text_lens, audio_codes_lens_target, dec_context_size ) - loss = codebook_loss + alignment_loss + loss = codebook_loss_scale * codebook_loss + alignment_loss else: - loss = codebook_loss + loss = codebook_loss_scale * codebook_loss + + local_transformer_loss = None + local_transformer_logits = None + if self.cfg.get('use_local_transformer', False): + local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_target) + local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target, audio_codes_lens_target) + local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0) + loss = loss + local_transformer_loss_scale * local_transformer_loss + + if aligner_encoder_loss is not None: + loss = loss + aligner_encoder_loss return { 'logits': logits, 'attn_info': attn_info, 'loss': loss, 'codebook_loss': codebook_loss, + 'local_transformer_loss' : local_transformer_loss, + 'local_transformer_logits' : local_transformer_logits, 'loss_mask': loss_mask, 'alignment_loss': alignment_loss, + 'aligner_encoder_loss': aligner_encoder_loss, 'audio_codes_target': audio_codes_target, 'audio_codes_lens_target': audio_codes_lens_target, 'text': context_tensors['text'], @@ -766,6 +995,8 @@ def process_batch(self, batch, 
mode="train"): 'context_audio_codes': context_tensors['context_audio_codes'], 'context_audio_codes_lens': context_tensors['context_audio_codes_lens'], 'dec_context_size': dec_context_size, + 'aligner_attn_soft': aligner_attn_soft, + 'aligner_attn_hard': aligner_attn_hard, } def training_step(self, batch, batch_idx): @@ -780,6 +1011,9 @@ def training_step(self, batch, batch_idx): if alignment_loss is not None: self.log('train_alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) self.log('train_loss', loss, prog_bar=True, sync_dist=True) + local_transformer_loss = batch_output['local_transformer_loss'] + if local_transformer_loss is not None: + self.log('train_local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) return loss @@ -788,6 +1022,7 @@ def validation_step(self, batch, batch_idx): loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] alignment_loss = batch_output['alignment_loss'] + aligner_encoder_loss = batch_output['aligner_encoder_loss'] logits = batch_output['logits'] audio_codes_target = batch_output['audio_codes_target'] audio_codes_lens_target = batch_output['audio_codes_lens_target'] @@ -798,6 +1033,8 @@ def validation_step(self, batch, batch_idx): dec_context_size = batch_output['dec_context_size'] if alignment_loss is None: alignment_loss = torch.tensor(0.0, device=loss.device) + if aligner_encoder_loss is None: + aligner_encoder_loss = torch.tensor(0.0, device=loss.device) if batch_idx == 0 and self.global_rank == 0: self.log_train_val_example( @@ -821,10 +1058,29 @@ def validation_step(self, batch, batch_idx): cross_attention_probs = [ attn_info[layer_idx]['cross_attn_probabilities'][0] ] self.log_attention_probs(cross_attention_probs, audio_codes_lens_target, text_lens, prefix=f"val_layer_{layer_idx}_", dec_context_size=dec_context_size) + if batch_output['aligner_attn_soft'] is not None: + self.log_attention_probs( + [batch_output['aligner_attn_soft']], + audio_codes_lens_target, + text_lens, + prefix=f"val_aligner_encoder_attn_", + ) + + if batch_output['aligner_attn_hard'] is not None: + self.log_attention_probs( + [batch_output['aligner_attn_hard'].unsqueeze(1)], + audio_codes_lens_target, + text_lens, + prefix=f"val_aligner_encoder_attn_hard_", + ) + + local_transformer_loss = batch_output['local_transformer_loss'] val_output = { 'val_loss': loss, 'val_codebook_loss': codebook_loss, 'val_alignment_loss': alignment_loss, + 'val_local_transformer_loss': local_transformer_loss, + 'val_aligner_encoder_loss': aligner_encoder_loss, } self.validation_step_outputs.append(val_output) @@ -849,6 +1105,96 @@ def get_cross_attention_scores(self, attn_probs, filter_layers=None): last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps return last_audio_timestep_scores, all_heads_cross_attn_scores + def get_most_attended_text_timestep(self, alignment_attention_scores, last_attended_timesteps, + text_lens, lookahead_window_size, attended_timestep_counter, batch_size): + """ + Returns the most attended timestep for each batch item + """ + text_time_step_attended = [] + for bidx in range(batch_size): + last_attended_timestep = last_attended_timesteps[-1][bidx] + if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8: + # This is probably an attention sink! 
Move to the next timestep + last_attended_timestep += 1 + window_size = lookahead_window_size + window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps + item_attention_scores = alignment_attention_scores[bidx,last_attended_timestep:window_end] + if item_attention_scores.size(0) == 0: + # This means the sentence has ended + attended_timestep = text_lens[bidx] - 1 + else: + attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep + text_time_step_attended.append(attended_timestep) + attended_timestep_counter[bidx][attended_timestep] = attended_timestep_counter[bidx].get(attended_timestep, 0) + 1 + return text_time_step_attended, attended_timestep_counter + + def construct_inference_prior(self, prior_epsilon, cross_attention_scores, + text_lens, text_time_step_attended, attended_timestep_counter, + unfinished_texts, finished_texts_counter, end_indices, batch_size): + # Attn prior for the next timestep + _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon + _attn_prior = _attn_prior.to(cross_attention_scores.device) + for bidx in range(cross_attention_scores.shape[0]): + if bidx < batch_size: + _text_len = text_lens[bidx] + if text_lens[bidx] <= 5: + # Very short sentences, No Prior + _attn_prior[bidx, 0, :] = 1.0 + else: + # _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-2)] = 0.1 # Slight exposure to history for better pronounciation. Not very important. + _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-1)] = 0.2 # Slight exposure to history for better pronounciation. Not very important. + _attn_prior[bidx, 0, text_time_step_attended[bidx]] = 0.8 # Slightly bias to continue moving forward. Not very important. + _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+1, _text_len - 1) ] = 1.0 + _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+2, _text_len - 1) ] = 0.8 + + # Penalize timesteps that have been attended to more than 10 times + for _timestep in attended_timestep_counter[bidx]: + if attended_timestep_counter[bidx][_timestep] >= 10: + # This means the timestep has been attended to more than 10 times (To avoid getting stuck) + _attn_prior[bidx, 0, _timestep] = prior_epsilon + + unfinished_texts[bidx] = False + if text_time_step_attended[bidx] < text_lens[bidx] - 3: + # This means the sentence has not ended + if bidx not in end_indices: + unfinished_texts[bidx] = True + + if text_time_step_attended[bidx] >= text_lens[bidx] - 5 or bidx in end_indices: + if bidx not in finished_texts_counter: + finished_texts_counter[bidx] = 0 + + for bidx in finished_texts_counter: + finished_texts_counter[bidx] += 1 + if finished_texts_counter[bidx] > 10: + # This means we have been within the text EOS window for atleast 10 timesteps + # We should allow EOS to be predicted now. 
+ unfinished_texts[bidx] = False + + return _attn_prior, unfinished_texts, finished_texts_counter + + def get_inference_attention_plots(self, cross_attention_scores_all_timesteps, all_heads_cross_attn_scores_all_timesteps, text_lens, predicted_codes_lens, batch_size, compute_all_heads_attn_maps): + cross_attention_scores_all_timesteps = torch.stack(cross_attention_scores_all_timesteps, dim=2) # B, text_timesteps, T' + headwise_cross_attention_scores_all_timesteps = [] + for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): + head_cross_attention_all_timesteps = torch.stack([x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2) # B, text_timesteps, T' + headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps) + + cross_attention_maps = [] + headwise_cross_attention_maps = [] + for bidx in range(batch_size): + item_cross_attention_scores = cross_attention_scores_all_timesteps[bidx,:text_lens[bidx],:predicted_codes_lens[bidx]] + cross_attn_np = plot_alignment_to_numpy(item_cross_attention_scores.cpu().numpy()) + cross_attention_maps.append(cross_attn_np) + item_all_head_cross_attn_maps = [] + if compute_all_heads_attn_maps: + for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): + item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][bidx,:text_lens[bidx],:predicted_codes_lens[bidx]] + headwise_cross_attn_np = plot_alignment_to_numpy(item_headwise_cross_attention_scores.cpu().numpy()) + item_all_head_cross_attn_maps.append(headwise_cross_attn_np) + headwise_cross_attention_maps.append(item_all_head_cross_attn_maps) + + return cross_attention_maps, headwise_cross_attention_maps + def infer_batch( self, batch, @@ -865,8 +1211,10 @@ def infer_batch( apply_prior_to_layers=None, start_prior_after_n_audio_steps=10, compute_all_heads_attn_maps=False, + use_local_transformer_for_inference=False, ): with torch.no_grad(): + start_time = time.time() self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference) context_tensors = self.prepare_context_tensors(batch) @@ -898,7 +1246,10 @@ def infer_batch( finished_texts_counter = {} attended_timestep_counter = [{} for _ in range(text.size(0))] last_attended_timesteps = [[1 for _ in range(text.size(0))]] # Maintain a list of attended timesteps as we predict audio for each batch item + time_to_first_prediction = 0.0 for idx in range(max_decoder_steps): + if idx == 1: + time_to_first_prediction = time.time() - start_time if idx % 20 == 0: print(f"Decoding timestep {idx}") audio_codes_embedded = self.embed_audio_tokens(audio_codes_input) @@ -922,9 +1273,7 @@ def infer_batch( attn_prior = [attn_prior, None] if use_cfg: - # import ipdb; ipdb.set_trace() batch_size = audio_codes_embedded.size(0) - # Combine conditional and unconditional inputs into one batch if isinstance(context_tensors['cond'], list): cfg_cond = [ torch.cat([cond_item, dummy_cond_item], dim=0) @@ -949,7 +1298,7 @@ def infer_batch( dummy_addition_dec_mask ) - combined_logits, attn_probs = self.forward( + combined_logits, attn_probs, dec_out = self.forward( dec_input_embedded=cfg_audio_codes_embedded, dec_input_mask=cfg_audio_codes_mask, cond=cfg_cond, @@ -963,7 +1312,7 @@ def infer_batch( all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits else: batch_size = audio_codes_embedded.size(0) - all_code_logits, attn_probs = self.forward( + all_code_logits, attn_probs, dec_out = self.forward( dec_input_embedded=_audio_codes_embedded, 
dec_input_mask=_audio_codes_mask, cond=context_tensors['cond'], @@ -977,76 +1326,50 @@ def infer_batch( alignment_attention_scores = cross_attention_scores if estimate_alignment_from_layers is not None: alignment_attention_scores, _ = self.get_cross_attention_scores(attn_probs, filter_layers=estimate_alignment_from_layers) # B, text_timesteps - text_time_step_attended = [] - for bidx in range(batch_size): - last_attended_timestep = last_attended_timesteps[-1][bidx] - if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8: - # This is probably an attention sink! Move to the next timestep - last_attended_timestep += 1 - window_size = lookahead_window_size - window_end = min(last_attended_timestep + window_size, context_tensors['text_lens'][bidx] - 3) # Ignore the last 3 timesteps - item_attention_scores = alignment_attention_scores[bidx,last_attended_timestep:window_end] - if item_attention_scores.size(0) == 0: - # This means the sentence has ended - attended_timestep = context_tensors['text_lens'][bidx] - 1 - else: - attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep - text_time_step_attended.append(attended_timestep) - attended_timestep_counter[bidx][attended_timestep] = attended_timestep_counter[bidx].get(attended_timestep, 0) + 1 - last_attended_timesteps.append(text_time_step_attended) cross_attention_scores_all_timesteps.append(cross_attention_scores) all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores) - # if idx % 20 == 0: - # print("At timesteps", idx, text_time_step_attended, context_tensors['text_lens']) if apply_attention_prior and idx >= start_prior_after_n_audio_steps: - eps = prior_epsilon - # Attn prior for the next timestep - _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + eps - _attn_prior = _attn_prior.to(cross_attention_scores.device) - for bidx in range(cross_attention_scores.shape[0]): - if bidx < batch_size: - _text_len = context_tensors['text_lens'][bidx] - if context_tensors['text_lens'][bidx] <= 5: - # Very short sentences, No Prior - _attn_prior[bidx, 0, :] = 1.0 - else: - # _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-2)] = 0.1 # Slight exposure to history for better pronounciation. Not very important. - _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-1)] = 0.2 # Slight exposure to history for better pronounciation. Not very important. - _attn_prior[bidx, 0, text_time_step_attended[bidx]] = 0.8 # Slightly bias to continue moving forward. Not very important. 
- _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+1, _text_len - 1) ] = 1.0 - _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+2, _text_len - 1) ] = 0.8 - - # Penalize timesteps that have been attended to more than 10 times - for _timestep in attended_timestep_counter[bidx]: - if attended_timestep_counter[bidx][_timestep] >= 10: - # This means the timestep has been attended to more than 10 times (To avoid getting stuck) - _attn_prior[bidx, 0, _timestep] = eps - - unfinished_texts[bidx] = False - if text_time_step_attended[bidx] < context_tensors['text_lens'][bidx] - 3: - # This means the sentence has definitely not ended - if bidx not in end_indices: - unfinished_texts[bidx] = True - - if text_time_step_attended[bidx] >= context_tensors['text_lens'][bidx] - 5 or bidx in end_indices: - if bidx not in finished_texts_counter: - finished_texts_counter[bidx] = 0 - - for key in finished_texts_counter: - finished_texts_counter[key] += 1 - if finished_texts_counter[key] > 10: - # We should allow EOS to be predicted now. - unfinished_texts[bidx] = False + text_time_step_attended, attended_timestep_counter = self.get_most_attended_text_timestep( + alignment_attention_scores=alignment_attention_scores, + last_attended_timesteps=last_attended_timesteps, + text_lens=context_tensors['text_lens'], + lookahead_window_size=lookahead_window_size, + attended_timestep_counter=attended_timestep_counter, + batch_size=batch_size + ) + last_attended_timesteps.append(text_time_step_attended) + _attn_prior, unfinished_texts, finished_texts_counter = self.construct_inference_prior( + prior_epsilon=prior_epsilon, + cross_attention_scores=cross_attention_scores, + text_lens=context_tensors['text_lens'], + text_time_step_attended=text_time_step_attended, + attended_timestep_counter=attended_timestep_counter, + unfinished_texts=unfinished_texts, + finished_texts_counter=finished_texts_counter, + end_indices=end_indices, + batch_size=batch_size + ) finished_items = {k: v for k, v in finished_texts_counter.items() if v >= 20} # Items that have been close to the end for atleast 20 timesteps unifinished_items = {k: v for k, v in unfinished_texts.items() if v} all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) - audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) - all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) - + if self.cfg.get('use_local_transformer', False) and use_local_transformer_for_inference: + audio_codes_next = self.sample_codes_from_local_transformer( + dec_output=dec_out[:,-1,:], + temperature=temperature, + topk=topk, + unfinished_items=unifinished_items, + finished_items=finished_items, + use_cfg=use_cfg, + cfg_scale=cfg_scale + ) + all_codes_next_argmax = audio_codes_next + else: + audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) + all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: @@ -1066,40 +1389,35 @@ def infer_batch( # Codec must be 
of atleast 4 timesteps to be decoded properly print("All ends reached") break + tts_generation_time = time.time() - start_time + tts_generation_time_per_frame = tts_generation_time / len(all_predictions) predicted_codes = torch.stack(all_predictions, dim=-1) # (B, num_codebooks, T') predicted_lens = [end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))] # Ensure that the codec is atleast of length 4 predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long() predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens) - + end_time = time.time() + total_audio_duration_generated = (predicted_audio_lens.max().item() * predicted_audio_lens.shape[0])/self._codec_model.sample_rate + rtf = total_audio_duration_generated / (end_time - start_time) + rtf_metrics = { + 'rtf': rtf, + 'time_to_first_prediction': time_to_first_prediction, + 'tts_generation_time': tts_generation_time, + 'max_frames_generated': len(all_predictions), + 'tts_generation_time_per_frame': tts_generation_time_per_frame, + 'batch_size': text.size(0), + } torch.cuda.empty_cache() if return_cross_attn_probs: - cross_attention_scores_all_timesteps = torch.stack(cross_attention_scores_all_timesteps, dim=2) # B, text_timesteps, T' - - headwise_cross_attention_scores_all_timesteps = [] - for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): - head_cross_attention_all_timesteps = torch.stack([x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2) # B, text_timesteps, T' - headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps) - - cross_attention_maps = [] - headwise_cross_attention_maps = [] - for bidx in range(predicted_audio.size(0)): - item_cross_attention_scores = cross_attention_scores_all_timesteps[bidx,:context_tensors['text_lens'][bidx],:predicted_codes_lens[bidx]] - cross_attn_np = plot_alignment_to_numpy(item_cross_attention_scores.cpu().numpy()) - cross_attention_maps.append(cross_attn_np) - item_all_head_cross_attn_maps = [] - if compute_all_heads_attn_maps: - for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): - item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][bidx,:context_tensors['text_lens'][bidx],:predicted_codes_lens[bidx]] - headwise_cross_attn_np = plot_alignment_to_numpy(item_headwise_cross_attention_scores.cpu().numpy()) - item_all_head_cross_attn_maps.append(headwise_cross_attn_np) - headwise_cross_attention_maps.append(item_all_head_cross_attn_maps) - - return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, cross_attention_maps, headwise_cross_attention_maps + cross_attention_maps, headwise_cross_attention_maps = self.get_inference_attention_plots( + cross_attention_scores_all_timesteps, all_heads_cross_attn_scores_all_timesteps, + context_tensors['text_lens'], predicted_codes_lens, text.size(0), compute_all_heads_attn_maps + ) + return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics, cross_attention_maps, headwise_cross_attention_maps else: # For backward compatibility - return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens + return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics def test_step(self, batch, batch_idx): with torch.no_grad(): @@ -1108,7 +1426,7 @@ def test_step(self, batch, batch_idx): topk = self.cfg.get('inference_topk', 80) use_cfg = 
self.cfg.get('inference_use_cfg', False) cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch, max_decoder_steps=self.cfg.get('max_decoder_steps', 500), temperature=temperature, @@ -1139,9 +1457,14 @@ def on_validation_epoch_end(self): val_loss = collect("val_loss") val_codebook_loss = collect("val_codebook_loss") val_alignment_loss = collect("val_alignment_loss") + val_aligner_encoder_loss = collect("val_aligner_encoder_loss") self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) self.log("val_codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) + self.log("val_aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) + if self.cfg.get('use_local_transformer', False): + val_local_transformer_loss = collect("val_local_transformer_loss") + self.log("val_local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() # free memory def get_dataset(self, cfg, dataset_type): @@ -1312,7 +1635,7 @@ def test_step(self, batch, batch_idx): topk = self.cfg.get('inference_topk', 80) use_cfg = self.cfg.get('inference_use_cfg', False) cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch, max_decoder_steps=self.cfg.get('max_decoder_steps', 500), temperature=temperature, diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index 28f498126d35..2d56b1c52465 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -227,7 +227,7 @@ def attn_naive( # attn_prior or square mask or vanilla attention if attn_prior is not None: - eps = 1e-8 + eps = self.prior_eps attn_prior = attn_prior[:, :T] # trim for inference attn_prior = torch.log(attn_prior + eps) attn_prior = attn_prior[:, None].repeat(1, self.n_heads, 1, 1) @@ -344,6 +344,7 @@ def __init__( d_model: int, d_memory: int, p_dropout: float, + prior_eps: float = 1e-8, ): """ Implements CrossAttention. See parent class for forward implementation. Must be non-causal. @@ -364,6 +365,7 @@ def __init__( raise ValueError("d_memory must be provided for cross-attention") self.q_net = torch.nn.Linear(d_model, n_heads * self.d_head, bias=False) self.kv_net = torch.nn.Linear(d_memory, 2 * n_heads * self.d_head, bias=False) + self.prior_eps = prior_eps def compute_qkv_and_mask( self, @@ -410,6 +412,7 @@ def __init__( apply_norm_to_cond: bool = True, max_length_causal_mask: int = 4096, conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), + prior_eps: float = 1e-8, ): """ One layer of the Transformer. @@ -447,6 +450,7 @@ def __init__( d_model=d_model, d_memory=xa_d_memory, p_dropout=p_dropout, + prior_eps=prior_eps, ) if self.apply_norm_to_cond: @@ -552,6 +556,7 @@ def __init__( max_length_causal_mask: int = 4096, use_learnable_pos_emb: bool = False, conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), + prior_eps: float = 1e-8, ): """ Initializes a stack of transformer layers. Can be used for both encoder and decoder. 
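For intuition about the new `prior_eps` parameter plumbed through the constructors above: in `attn_naive` the prior is folded into the attention scores as `log(attn_prior + eps)`, so `eps` decides how strongly positions with a near-zero prior are suppressed. A minimal illustrative sketch (not part of the patch; values are made up):

```python
import torch

# Hypothetical prior over 4 text positions for one decoder step.
prior = torch.tensor([0.0, 0.2, 0.8, 1.0])

for eps in (1e-8, 1e-3, 1e-1):
    bias = torch.log(prior + eps)  # additive bias applied to the attention logits
    print(f"eps={eps:g} -> bias={[round(b, 2) for b in bias.tolist()]}")

# eps=1e-8 pushes the zero-prior position down by about -18.4 (close to a hard mask),
# while eps=1e-1 biases it by only about -2.3 (a soft preference), which is the
# behaviour the now-configurable prior_eps exposes.
```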
@@ -609,6 +614,7 @@ def __init__( apply_norm_to_cond=apply_norm_to_cond, max_length_causal_mask=max_length_causal_mask, conv_non_linearity=conv_non_linearity, + prior_eps=prior_eps, ) ) @@ -673,6 +679,7 @@ def forward( cond_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, attn_prior: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, multi_encoder_mapping: Optional[List[Optional[int]]] = None, + max_layer_idx: Optional[int] = None, ) -> Dict[str, Union[torch.Tensor, List]]: """ Args: @@ -710,6 +717,9 @@ def forward( x = out_dict['output'] attn_probabilities.append(out_dict['attn_probabilities']) + if max_layer_idx is not None and idx == max_layer_idx: + break + if self.norm_out is not None: x = self.norm_out(x) diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 897add3b4a44..3e993000bf78 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -25,6 +25,12 @@ 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', }, + 'libri_val_12.5': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', + 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'load_cached_codes_if_available': False + }, 'libri_val_shehzeen': { 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/libri360_val.json', 'audio_dir' : '/Data/LibriTTS', @@ -35,6 +41,12 @@ 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', }, + 'libri_unseen_test_12.5': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'load_cached_codes_if_available': False + }, 'libri_unseen_test_shehzeen_phoneme': { 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', 'audio_dir' : '/Data/LibriTTS', diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 315c6a2e832a..917c6a8fcc8d 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -46,7 +46,8 @@ def run_inference( estimate_alignment_from_layers=None, apply_prior_to_layers=None, start_prior_after_n_audio_steps=10, - confidence_level=0.95 + confidence_level=0.95, + use_local_transformer=False ): # import ipdb; ipdb.set_trace() model_cfg = OmegaConf.load(hparams_file).cfg @@ -75,7 +76,7 @@ def run_inference( # import ipdb; ipdb.set_trace() checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] - checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}".format( + checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}".format( checkpoint_name, temperature, topk, @@ -86,7 +87,8 @@ def run_inference( attention_prior_lookahead_window, start_prior_after_n_audio_steps, "".join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else "None", - "".join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else "None" + "".join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else "None", + use_local_transformer ) dataset_meta_info = evalset_config.dataset_meta_info for dataset in datasets: @@ -144,6 +146,7 @@ def 
run_inference( ) item_idx = 0 + all_rtf_metrics = [] for bidx, batch in enumerate(test_data_loader): print("Processing batch {} out of {} of dataset {}".format(bidx, len(test_data_loader), dataset)) batch_cuda ={} @@ -155,7 +158,7 @@ def run_inference( import time st = time.time() - predicted_audio, predicted_audio_lens, _, _, cross_attention_maps, _ = model.infer_batch( + predicted_audio, predicted_audio_lens, _, _, rtf_metrics, cross_attention_maps, _ = model.infer_batch( batch_cuda, max_decoder_steps=440, temperature=temperature, @@ -168,9 +171,10 @@ def run_inference( lookahead_window_size=attention_prior_lookahead_window, estimate_alignment_from_layers=estimate_alignment_from_layers, apply_prior_to_layers=apply_prior_to_layers, - start_prior_after_n_audio_steps=start_prior_after_n_audio_steps + start_prior_after_n_audio_steps=start_prior_after_n_audio_steps, + use_local_transformer_for_inference=use_local_transformer ) - + all_rtf_metrics.append(rtf_metrics) et = time.time() print(f"Time taken for inference: {et-st}", predicted_audio.size()) for idx in range(predicted_audio.size(0)): @@ -193,6 +197,10 @@ def run_inference( shutil.copy(target_audio_path, os.path.join(audio_dir, f"target_audio_{item_idx}.wav")) item_idx += 1 + mean_rtf_metrics = {} + for key in all_rtf_metrics[0]: + mean_rtf_metrics[key] = float(np.mean([m[key] for m in all_rtf_metrics])) + metrics, filewise_metrics = evaluate_generated_audio.evaluate( dataset_meta[dataset]['manifest_path'], dataset_meta[dataset]['audio_dir'], @@ -206,6 +214,9 @@ def run_inference( with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: # Indent for better readability json.dump(filewise_metrics, f, indent=4) + + with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: + json.dump(mean_rtf_metrics, f, indent=4) all_experiment_csv = os.path.join(out_dir, "all_experiment_metrics.csv") if not os.path.exists(all_experiment_csv): @@ -234,17 +245,18 @@ def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt") - parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo") - parser.add_argument('--datasets', type=str, default="libri_seen_test,libri_unseen_test") - parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmount4/AllKernselSize3/NewTransformer") - parser.add_argument('--draco_exp_dir', type=str, default="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/experiments/NewT5TTS_FixedPosEmb/AllKernselSize3/NewTransformer") + parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo") + parser.add_argument('--datasets', type=str, default="libri_unseen_test_12.5") + parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmountedresson/") + parser.add_argument('--draco_exp_dir', type=str, 
default="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/experiments/NewT5TTS_FixedPosEmb/AllKernselSize3/EdressonCodecExperiments/") parser.add_argument('--server_address', type=str, default="pneekhara@login-eos02.eos.clusters.nvidia.com") - parser.add_argument('--exp_names', type=str, default="multiencoder_small_sp_ks3_lnormapplied") - parser.add_argument('--local_ckpt_dir', type=str, default="/datap/misc/continuouscheckpoints_fixedposembrough") - parser.add_argument('--out_dir', type=str, default="/datap/misc/ContinuousEvalResults/NewTransformerKoelTTS") + parser.add_argument('--exp_names', type=str, default="koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN1,koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN3") + parser.add_argument('--local_ckpt_dir', type=str, default="/datap/misc/experiment_checkpoints/localtransformer") + parser.add_argument('--out_dir', type=str, default="/datap/misc/Evals/LocalTransformerAblations2") parser.add_argument('--temperature', type=float, default=0.6) parser.add_argument('--use_cfg', action='store_true') - parser.add_argument('--cfg_scale', type=float, default=1.0) + parser.add_argument('--use_local_transformer', action='store_true') + parser.add_argument('--cfg_scale', type=float, default=2.5) parser.add_argument('--apply_attention_prior', action='store_true') parser.add_argument('--attention_prior_epsilon', type=float, default=1e-3) parser.add_argument('--attention_prior_lookahead_window', type=int, default=10) @@ -252,7 +264,7 @@ def main(): parser.add_argument('--apply_prior_to_layers', type=str, default=None) parser.add_argument('--start_prior_after_n_audio_steps', type=int, default=10) parser.add_argument('--topk', type=int, default=80) - parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) args = parser.parse_args() @@ -290,6 +302,7 @@ def main(): apply_prior_to_layers=apply_prior_to_layers, start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, + use_local_transformer=args.use_local_transformer ) return else: @@ -349,6 +362,7 @@ def main(): apply_prior_to_layers=apply_prior_to_layers, start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, + use_local_transformer=args.use_local_transformer ) diff --git a/t5tts_inference.ipynb b/t5tts_inference.ipynb index f60833bd8186..875809c6fa94 100644 --- a/t5tts_inference.ipynb +++ b/t5tts_inference.ipynb @@ -39,9 +39,11 @@ }, "outputs": [], "source": [ - "hparams_file = \"/datap/misc/ChallengingFinetuneLocalTraining/head1hparams.yaml\"\n", - "checkpoint_file = \"/datap/misc/ChallengingFinetuneLocalTraining/head1_epoch248.ckpt\"\n", - "codecmodel_path = \"/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo\"\n", + "hparams_file = \"/datap/misc/experiment_checkpoints/localtransformer/koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN3_hparams.yaml\"\n", + "checkpoint_file = \"/datap/misc/experiment_checkpoints/localtransformer/koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN3_epoch101.ckpt\"\n", + "# codecmodel_path = \"/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo\"\n", + "codecmodel_path = \"/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo\"\n", + "\n", "\n", "# Temp out dir for saving audios\n", "out_dir = 
\"/datap/misc/t5tts_inference_notebook_samples\"\n", @@ -178,13 +180,19 @@ }, "outputs": [], "source": [ + "import pprint\n", "# Change sample text and prompt audio/text here\n", "audio_base_dir = \"/\"\n", "test_entries = [\n", " create_record(\n", - " text=\"This is an example of a regular text without any challlenging texts like repeated words or numbers just to do a sanity check.\",\n", - " context_text=\"Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_WIZWIKI |\",\n", - " #context_audio_filepath=\"/datap/misc/LibriTTSfromNemo/LibriTTS/test-clean/7729/102255/7729_102255_000012_000001.wav\", # Supply either context_audio_filepath or context_text, not both\n", + " text=\"The recollection of a unique event cannot, so Bergson contends, be wholly constituted by habit, and is in fact something radically different from the memory which is habit.\",\n", + "# text=\"This is a simple sentence to check my text to speech synthesis model.\",\n", + "# text=\" How? \",\n", + "# context_text=\"Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_WIZWIKI |\",\n", + "# Lindy_CMU_FEARFUL\n", + "# Lindy_WIZWIKI\n", + "# context_audio_filepath=\"/datap/misc/LibriTTSfromNemo/LibriTTS/test-clean/7729/102255/7729_102255_000012_000001.wav\", # Supply either context_audio_filepath or context_text, not both\n", + " context_audio_filepath=\"/datap/misc/LibriTTSfromNemo/LibriTTS/test-clean/8230/279154/8230_279154_000004_000009.wav\",\n", " ),\n", "]\n", "\n", @@ -246,45 +254,65 @@ " st = time.time()\n", " \n", " for _ in range(1):\n", - " for apply_prior in [True, False]:\n", - " predicted_audio, predicted_audio_lens, _, _, cross_attn_np, all_heads_attn_np = model.infer_batch(\n", - " batch_cuda, \n", - " max_decoder_steps=430, \n", - " temperature=0.6, \n", - " topk=80, \n", - " use_cfg=True,\n", - " cfg_scale=2.5,\n", - " prior_epsilon=0.1,\n", - " lookahead_window_size=5,\n", - " return_cross_attn_probs=True,\n", - " estimate_alignment_from_layers=[4,6,7],\n", - " apply_attention_prior=apply_prior,\n", - " apply_prior_to_layers=[3,4,5,6,7,8,9,10],\n", - " compute_all_heads_attn_maps=True,\n", - " start_prior_after_n_audio_steps=10\n", - " )\n", - " print(\"generation time\", time.time() - st)\n", - " for idx in range(predicted_audio.size(0)):\n", - " predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()\n", - " predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]\n", - " audio_path = os.path.join(out_dir, f\"predicted_audio_{item_idx}.wav\")\n", - " sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)\n", - " print(test_entries[bidx]['text'])\n", - " print(\"Prior Used?\", apply_prior)\n", - " display(Audio(audio_path))\n", - " item_idx += 1\n", - " plt.imshow(cross_attn_np[idx])\n", - " plt.show()\n", - "# for hidx, head_cross_attn in enumerate(all_heads_attn_np[idx]):\n", - "# layer_num = hidx // model.cfg.t5_decoder.xa_n_heads\n", - "# head_num = hidx % model.cfg.t5_decoder.xa_n_heads\n", - "# print(\"item, layer, head\", idx, layer_num, head_num)\n", - "# plt.imshow(all_heads_attn_np[idx][hidx])\n", - "# plt.show()\n", + " for use_local_transformer_for_inference in [False, True]:\n", + " for apply_prior in [False]:\n", + " predicted_audio, predicted_audio_lens, _, _, rtf_metrics, cross_attn_np, all_heads_attn_np = model.infer_batch(\n", + " batch_cuda, \n", + " max_decoder_steps=430, \n", + " temperature=0.6, \n", + " topk=80, \n", + " use_cfg=False,\n", + " cfg_scale=2.5,\n", + " prior_epsilon=0.1,\n", + " lookahead_window_size=5,\n", + " 
return_cross_attn_probs=True,\n", + " estimate_alignment_from_layers=[5],\n", + " apply_attention_prior=apply_prior,\n", + " apply_prior_to_layers=[0,1,2,3,4,5,6,7,8,9,10,11],\n", + " compute_all_heads_attn_maps=True,\n", + " start_prior_after_n_audio_steps=0,\n", + " use_local_transformer_for_inference=use_local_transformer_for_inference\n", + " )\n", + " print(\"generation time\", time.time() - st)\n", + " pprint.pprint(rtf_metrics)\n", + " for idx in range(predicted_audio.size(0)):\n", + " predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()\n", + " predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]\n", + " audio_path = os.path.join(out_dir, f\"predicted_audio_{item_idx}.wav\")\n", + " sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)\n", + " print(test_entries[bidx]['text'])\n", + " print(\"Prior Used?\", apply_prior)\n", + " print(\"use_local_transformer\", use_local_transformer_for_inference)\n", + " display(Audio(audio_path))\n", + " item_idx += 1\n", + " plt.imshow(cross_attn_np[idx])\n", + " plt.show()\n", + "# for hidx, head_cross_attn in enumerate(all_heads_attn_np[idx]):\n", + "# layer_num = hidx // model.cfg.t5_decoder.xa_n_heads\n", + "# head_num = hidx % model.cfg.t5_decoder.xa_n_heads\n", + "# print(\"item, layer, head\", idx, layer_num, head_num)\n", + "# plt.imshow(all_heads_attn_np[idx][hidx])\n", + "# plt.show()\n", "\n", - " print(\"------------------------------------\")\n", - " print(\"------------------------------------\")" + " print(\"------------------------------------\")\n", + " print(\"------------------------------------\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7adb9800", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b20dab52", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -303,7 +331,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.0" + "version": "3.10.12" } }, "nbformat": 4, From fb30e3840d730522804ea8473f3e9f2556a7247a Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Thu, 13 Mar 2025 16:38:18 -0700 Subject: [PATCH 006/113] Koel onlinepo, GRPO (#54) Updated by Jason, added back inference class * wavlm speaker eval Signed-off-by: Shehzeen Hussain * connect to inference script Signed-off-by: Shehzeen Hussain * bug fix Signed-off-by: Shehzeen Hussain * grpo started, training seems to be working Signed-off-by: Shehzeen Hussain * grpo local training seems ok Signed-off-by: Shehzeen Hussain * only one generation per item in val Signed-off-by: Shehzeen Hussain * allow cfg use during generation process Signed-off-by: Shehzeen Hussain * fix cer threshold for 0 reward Signed-off-by: Shehzeen Hussain * use kv cache for grpo generation Signed-off-by: Shehzeen Hussain * remove kv cache for now Signed-off-by: Shehzeen Hussain * kv cache for online po configurable Signed-off-by: Shehzeen Hussain * configurable reward params Signed-off-by: Shehzeen Hussain * grpo val set added in evalset Signed-off-by: Shehzeen Hussain * comments update Signed-off-by: Shehzeen Hussain * modify reward scaling Signed-off-by: Shehzeen Hussain * moved preference optimization code and classes to a new file Signed-off-by: Shehzeen Hussain * missing file Signed-off-by: Shehzeen Hussain * added language option in online PO Signed-off-by: Shehzeen Hussain * some updates in the script Signed-off-by: Shehzeen Hussain * add reference free option Signed-off-by: 
Shehzeen Hussain * handle corner cases Signed-off-by: Shehzeen Hussain * bug fix in reference free mode and torch.load fix for new container Signed-off-by: Shehzeen Hussain * added option for pesq reward Signed-off-by: Shehzeen Hussain * pesq device bug fix Signed-off-by: Shehzeen Hussain --------- Signed-off-by: Shehzeen Hussain --- examples/tts/magpietts.py | 15 +- examples/tts/t5tts_commands.md | 462 ++++++++++ .../tts/data/text_to_speech_dataset.py | 7 + nemo/collections/tts/models/__init__.py | 4 + nemo/collections/tts/models/magpietts.py | 329 +------ .../models/t5tts_preference_optimization.py | 815 ++++++++++++++++++ nemo/core/classes/modelPT.py | 2 +- .../magpietts/dpo/create_text_contextpairs.py | 48 +- scripts/magpietts/evalset_config.py | 5 + scripts/magpietts/infer_and_evaluate.py | 21 +- scripts/t5tts/evaluate_generated_audio.py | 253 ++++++ 11 files changed, 1603 insertions(+), 358 deletions(-) create mode 100644 examples/tts/t5tts_commands.md create mode 100644 nemo/collections/tts/models/t5tts_preference_optimization.py create mode 100644 scripts/t5tts/evaluate_generated_audio.py diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index af6a6f9a1752..a9de63ca5582 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -15,7 +15,7 @@ import pytorch_lightning as pl from omegaconf import OmegaConf, open_dict -from nemo.collections.tts.models import MagpieTTS_Model, MagpieTTS_ModelDPO, MagpieTTS_ModelInference +from nemo.collections.tts.models import MagpieTTS_Model, MagpieTTS_ModelInference, T5TTS_Model_PrefDataGen, T5TTS_Model_OfflinePO, T5TTS_Model_OnlinePO from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @@ -34,19 +34,26 @@ def main(cfg): if cfg.get('mode', 'train') == 'train': model = MagpieTTS_Model(cfg=cfg.model, trainer=trainer) - elif cfg.get('mode', 'dpo_train') == 'dpo_train': + elif cfg.get('mode', 'train') == 'dpo_train': model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt - model = MagpieTTS_ModelDPO(cfg=model_cfg, trainer=trainer) + model = T5TTS_Model_OfflinePO(cfg=model_cfg, trainer=trainer) + elif cfg.get('mode', 'train') == 'onlinepo_train': + model_cfg = cfg.model + with open_dict(model_cfg): + model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt + model = T5TTS_Model_OnlinePO(cfg=model_cfg, trainer=trainer) elif cfg.get('mode', 'train') == 'test': model = MagpieTTS_ModelInference(cfg=cfg.model, trainer=trainer) + # elif cfg.get('mode', 'train') == 'test': + # model = T5TTS_Model_PrefDataGen(cfg=cfg.model, trainer=trainer) else: raise NotImplementedError(f"Only train, dpo_train and test modes are supported. Got {cfg.mode}") model.maybe_init_from_pretrained_checkpoint(cfg=cfg) - if cfg.get('mode', 'train') in ['train', 'dpo_train']: + if cfg.get('mode', 'train') in ['train', 'dpo_train', 'onlinepo_train']: trainer.fit(model) elif cfg.get('mode', 'train') == 'test': trainer.test(model) diff --git a/examples/tts/t5tts_commands.md b/examples/tts/t5tts_commands.md new file mode 100644 index 000000000000..4b76ceb157fb --- /dev/null +++ b/examples/tts/t5tts_commands.md @@ -0,0 +1,462 @@ +## Docker Container + +`nvcr.io/nvidia/nemo:dev` or `nvcr.io/nvidia/nemo:24.07` + +I have an sqsh file already built for `nvcr.io/nvidia/nemo:dev` - much faster to start on eos/draco than giving the above docker containers. 
+Sqsh path on EOS: `/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/launchscripts/nemodevNov24.sqsh` + +Docker commands I run locally + +``` +docker run --runtime=nvidia -it --rm --shm-size=16g --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v /home/pneekhara/2023:/home/pneekhara/2023 -v /datap/misc/:/datap/misc/ -v ~/.cache/torch:/root/.cache/torch -v ~/.netrc:/root/.netrc -v ~/.ssh:/root/.ssh --net=host nvcr.io/nvidia/nemo:dev +``` + +``` +cd /home/pneekhara/2023/SimpleT5NeMo/NeMo; export PYTHONPATH="/home/pneekhara/2023/SimpleT5NeMo/NeMo.:${PYTHONPATH}" ; +``` + +## Code +* Model `nemo/collections/tts/models/t5tts.py` +* Dataset Class `T5TTSDataset` in `nemo/collections/tts/data/text_to_speech_dataset.py` +* Transformer Module `nemo/collections/tts/modules/t5tts_transformer.py` +* Config Yaml `examples/tts/conf/t5tts/t5tts.yaml` +* Training/Inference Script `examples/tts/t5tts.py` + +## Model Types + +Currently supported model types (`cfg.model.model_type` in t5tts.yaml): `single_encoder_sv_tts`, `multi_encoder_context_tts`, `decoder_context_tts` and `decoder_pretrain_synthesizer`. + +1. `single_encoder_sv_tts` is a simple T5 model: Text goes into the encoder and target audio goes to the decoder. + Additionally, the speaker embedding of the target audio (or context audio if provided) from TitaNet gets added to the encoder output (all timesteps). Text context is not supported in this model. + +2. `multi_encoder_context_tts` is a multi-encoder T5 model: Transcript and context audio go to different encoders. + The transcript encoding feeds into the layers given by `cfg.model.transcript_decoder_layers` and the context encoding feeds into the layers given by `context_decoder_layers`. + Also supports text context, which gets encoded by the same encoder as context audio. Only one of context audio or context text is supported. + +3. `decoder_context_tts`: Text goes into the encoder; context & target audio go to the decoder. + Also supports text context. Currently, I have tested the model using fixed-size contexts, so I set `context_duration_min` and `context_duration_max` to the same value (5 seconds). Text context, which is usually shorter than the number of codec frames for 5 seconds of audio, is padded to the max context duration in this model. + +4. `decoder_pretrain_synthesizer`: This is the model type used for pretraining the decoder only on audio data using next frame prediction loss. + +## Training + +### Manifest structure +For `single_encoder_sv_tts`, the manifest json files should contain the following keys: `audio_filepath, duration, text, speaker`. `speaker` is not currently used, so it can be anything. Optionally, we can have a `context_audio_filepath` and `context_audio_duration` as well, if we want to use that for the speaker embedding instead of the `audio_filepath`. +If we have already extracted the audio codes then they can also contain the key `target_audio_codes_path` pointing to the absolute path to the codes .pt file of shape (8, T). +Note: `target_audio_codes_path` should either be present in ALL training manifests or absent in ALL training manifests. The train set cannot be a mix of both. Same goes for the val set. +If `target_audio_codes_path` is not present, codes are extracted on the fly (and training will be slower). + +For `multi_encoder_context_tts` and `decoder_context_tts`, in addition to the above, the manifest should contain `context_audio_filepath` and `context_audio_duration`. If we have codes already extracted, we can have `context_audio_codes_path` (absolute path) instead of `context_audio_filepath`.
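+
+For reference, a single manifest line for `multi_encoder_context_tts` could look like the following (paths and values here are purely illustrative); the `*_codes_path` entries are optional absolute paths, while the audio filepaths are relative to `audio_dir`:
+
+```
+{"audio_filepath": "train-clean-360/14/208/14_208_000001_000000.wav", "duration": 5.3, "text": "An example transcript.", "speaker": "14", "context_audio_filepath": "train-clean-360/14/208/14_208_000002_000000.wav", "context_audio_duration": 4.1, "target_audio_codes_path": "/abs/path/to/codes/14_208_000001_000000.pt", "context_audio_codes_path": "/abs/path/to/codes/14_208_000002_000000.pt"}
+```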
+ +For text context training, we can have `context_text` key for text context and drop `context_audio_duration` and `context_audio_filepath` (or `context_audio_codes_path`). + +If we have both `audio_filepath` and `target_audio_codes_path` in the manifest, the dataloader will load from `target_audio_codes_path`. To disable this and extract codes on the fly set the parameter `model.load_cached_codes_if_available=false` during training. Same goes for context audio. + +### Manifests and Datasets + +Manifests can be found in: `/lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/manifests` on draco-oci (`draco-oci-dc-02.draco-oci-iad.nvidia.com`) +I use the following for training. + +``` +Train: +hifitts__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json +rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json +rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_textContextsimplet5_withContextAudioPaths.json +libri100__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json +libri360__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json +mls17k__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_verified_simplet5_withContextAudioPaths.json + +Val: +dev_clean_withContextAudioPaths.json +``` + +Audio File Directories: +``` +HifiTTS: /lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/hi_fi_tts_v0 +Libri100, Libri360 Libri dev: /lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/LibriTTS +Lindy/Rodney: /lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/riva +MLS Audio: /lustre/fsw/portfolios/edgeai/projects/edgeai_riva_rivamlops/data/tts/datasets/mls17k/filtered_24khz/audio_24khz +``` + +Pre-extracted Audio Codes (21 FPS with WavLM) +``` +/lustre/fs11/portfolios/edgeai/projects/edgeai_riva_rivamlops/data/tts/datasets/codecs +``` + +### Command +``` +python examples/tts/t5tts.py \ +--config-name=t5tts \ +max_epochs=1000 \ +weighted_sampling_steps_per_epoch=1000 \ +exp_manager.exp_dir="/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/" \ ++train_ds_meta.rivatrain.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json" \ ++train_ds_meta.rivatrain.audio_dir="/datap/misc/Datasets/riva" \ ++train_ds_meta.rivatrain.feature_dir="/datap/misc/Datasets/riva" \ ++train_ds_meta.rivatrain.sample_weight=1.0 \ ++train_ds_meta.libri360train.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json" \ ++train_ds_meta.libri360train.audio_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ ++train_ds_meta.libri360train.feature_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ ++train_ds_meta.libri360train.sample_weight=1.0 \ ++train_ds_meta.libri100train.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/libri100__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json" \ ++train_ds_meta.libri100train.audio_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ ++train_ds_meta.libri100train.feature_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ ++train_ds_meta.libri100train.sample_weight=1.0 \ ++val_ds_meta.librival.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/dev_clean_withcontext.json" \ ++val_ds_meta.librival.audio_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ 
++val_ds_meta.librival.feature_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ +model.model_type="single_encoder_sv_tts" \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path="/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.alignment_loss_scale=0.005 \ +model.prior_scaling_factor=0.5 \ +model.prior_scaledown_start_step=5000 \ +model.prior_end_step=8000 \ +model.context_duration_min=3.0 \ +model.context_duration_max=8.0 \ +model.train_ds.dataloader_params.num_workers=2 \ +model.validation_ds.dataloader_params.num_workers=2 \ +trainer.val_check_interval=500 \ +trainer.devices=-1 \ +~model.optim.sched ; +``` + +Audio filepaths in the manifests should be relative to `audio_dir`. Codec paths are absolute. + +Set `model.model_type=multi_encoder_context_tts` for Multi Encoder T5TTS or `decoder_context_tts` for decoder context and `model.use_text_conditioning_encoder=true` if you want both audio/text contexts. + +### Command Lhotse dataset +``` +python examples/tts/t5tts.py \ + --config-name=t5tts_lhotse.yaml \ + batch_size=32 \ + micro_batch_size=32 \ + max_steps=1000000 \ + limit_val_batches=20 \ + trainer.max_steps=1000000 \ + trainer.val_check_interval=500 \ + exp_manager.exp_dir="/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/" \ + model.codecmodel_path="/home/ecasanova/Projects/Checkpoints/Audio_codec/21Hz-no-eliz/AudioCodec_21Hz_no_eliz.nemo" \ + model.alignment_loss_scale=0.01 \ + model.prior_scaling_factor=0.5 \ + model.prior_scaledown_start_step=5000 \ + model.prior_end_step=8000 \ + model.t5_encoder.use_flash_self_attention=true \ + model.t5_encoder.use_flash_x_attention=true \ + model.t5_decoder.use_flash_self_attention=true \ + model.t5_decoder.use_flash_x_attention=false \ + trainer.devices=1 \ + ++model.load_cached_codes_if_available=False \ + ++model.num_audio_codebooks=8 \ + ++model.num_audio_tokens_per_codebook=2048 \ + ++model.codec_model_downsample_factor=1024 \ + ~model.optim.sched ; + +HYDRA_FULL_ERROR=1 PYTHONFAULTHANDLER=1 python examples/tts/t5tts.py \ + --config-name=t5tts_lhotse.yaml \ + exp_manager.exp_dir="/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/" \ + +exp_manager.version=0 \ + eval_batch_size=64 \ + batch_size=384 \ + micro_batch_size=24 \ + max_steps=5000000 \ + batch_duration=350 \ + limit_val_batches=25 \ + trainer.max_steps=5000000 \ + model.codecmodel_path="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ + ++model.train_ds.dataset.input_cfg.0.type="lhotse_shar" \ + ++model.train_ds.dataset.input_cfg.0.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/hifitts_v0/" \ + ++model.train_ds.dataset.input_cfg.0.weight=1.0 \ + ++model.train_ds.dataset.input_cfg.0.tags.lang="en" \ + ++model.train_ds.dataset.input_cfg.0.tags.s2s=True \ + ++model.train_ds.dataset.input_cfg.0.tags.tokenizer_names=["english_phoneme"] \ + ++model.train_ds.dataset.input_cfg.1.type="lhotse_shar" \ + ++model.train_ds.dataset.input_cfg.1.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/libri100/" \ + ++model.train_ds.dataset.input_cfg.1.weight=1.0 \ + ++model.train_ds.dataset.input_cfg.1.tags.lang="en" \ + ++model.train_ds.dataset.input_cfg.1.tags.s2s=True \ + ++model.train_ds.dataset.input_cfg.1.tags.tokenizer_names=["english_phoneme"] \ + ++model.train_ds.dataset.input_cfg.2.type="lhotse_shar" \ + ++model.train_ds.dataset.input_cfg.2.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/rivaLindyRodney/" \ 
+ ++model.train_ds.dataset.input_cfg.2.weight=1.0 \ + ++model.train_ds.dataset.input_cfg.2.tags.lang="en" \ + ++model.train_ds.dataset.input_cfg.2.tags.s2s=True \ + ++model.train_ds.dataset.input_cfg.2.tags.tokenizer_names=["english_phoneme"] \ + ++model.train_ds.dataset.input_cfg.3.type="lhotse_shar" \ + ++model.train_ds.dataset.input_cfg.3.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/libri360/" \ + ++model.train_ds.dataset.input_cfg.3.weight=1.0 \ + ++model.train_ds.dataset.input_cfg.3.tags.lang="en" \ + ++model.train_ds.dataset.input_cfg.3.tags.s2s=True \ + ++model.train_ds.dataset.input_cfg.3.tags.tokenizer_names=["english_phoneme"] \ + ++model.validation_ds.dataset.input_cfg.0.type="lhotse_shar" \ + ++model.validation_ds.dataset.input_cfg.0.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/LibriTTS_dev_clean/" \ + ++model.validation_ds.dataset.input_cfg.0.weight=1.0 \ + ++model.validation_ds.dataset.input_cfg.0.tags.lang="en" \ + ++model.validation_ds.dataset.input_cfg.0.tags.s2s=True \ + ++model.validation_ds.dataset.input_cfg.0.tags.tokenizer_names=["english_phoneme"] \ + model.alignment_loss_scale=0.01 \ + model.prior_scaling_factor=0.5 \ + model.prior_scaledown_start_step=5000 \ + model.prior_end_step=8000 \ + model.t5_encoder.use_flash_self_attention=true \ + model.t5_encoder.use_flash_x_attention=true \ + model.t5_decoder.use_flash_self_attention=true \ + model.t5_decoder.use_flash_x_attention=false \ + trainer.val_check_interval=50 \ + trainer.devices=8 \ + ++model.load_cached_codes_if_available=False \ + ++model.num_audio_codebooks=8 \ + ++model.num_audio_tokens_per_codebook=2048 \ + ++model.codec_model_downsample_factor=1024 \ + model.optim.lr=2e-4 \ + trainer.num_nodes=${SLURM_JOB_NUM_NODES} + +``` +Set `model.model_type=multi_encoder_context_tts` for Multi Encoder T5TTS and `model.use_text_conditioning_encoder=true` if you are doing text context training. + +If you change the codec model, make sure to adjust these model config params in `t5tts.yaml`: + +``` +model: + num_audio_codebooks: 8 + num_audio_tokens_per_codebook: 2048 # Keep at least 4 extra for eos/bos ids + codec_model_downsample_factor: 1024 +``` + +To train the model without CTC loss and prior, set the params below: + +``` +model.alignment_loss_scale=0.0 \ +model.prior_scaling_factor=null \ +``` + +### Training sub files on cluster + +| Model Type | Cluster | Training Sub File | +|------------|---------|--------| +| multi_encoder_context_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalized_me.sub | +| decoder_context_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalizedt5_decoder.sub | +| single_encoder_sv_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalizedt5_singleencoder.sub | +| decoder_pretrain_synthesizer | login-eos | /lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/scriptsSimpleT5/newt5_pretrain.sub | + +## Pretrained Models and Results + +Paths to pretrained checkpoints and their evaluation results on some test sets can be found [here](https://docs.google.com/spreadsheets/d/16AkvAHZ-ytWYnzEx9wtOG7yLkuU2wfB8gGMiDa5sROg/edit?usp=sharing). + +## Inference and Eval + +To infer and evaluate from a given checkpoint and hparams.yaml file, I use `scripts/t5tts/infer_and_evaluate.py`.
To evaluate on a given manifest (same structure as discussed above), edit the `dataset_meta_info` in `scripts/t5tts/infer_and_evaluate.py` to point to the paths on your machine or add any other datasets if missing. + +``` +dataset_meta_info = { + 'vctk': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json', + 'audio_dir' : '/datap/misc/Datasets/VCTK-Corpus', + 'feature_dir' : '/datap/misc/Datasets/VCTK-Corpus', + }, + 'riva_challenging': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', + 'audio_dir' : '/datap/misc/Datasets/riva', + 'feature_dir' : '/datap/misc/Datasets/riva', + }, + 'libri_val': { + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', + 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', + } +} +``` + +Then run + +``` +python scripts/t5tts/infer_and_evaluate.py \ +--hparams_file <hparams_file> \ +--checkpoint_file <checkpoint_file> \ +--codecmodel_path /datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo \ +--datasets "vctk,libri_val" \ +--out_dir /datap/misc/Evals \ +--temperature 0.6 \ +--topk 80 +``` + +Ignore the other params in the file; I also use this for evaluating ongoing experiments on the cluster by copying over the checkpoints and hparams. + +### Inference Notebook + +Inference Notebook: `t5tts_inference.ipynb`, for quickly trying custom texts/contexts. + +### Offline Preference Alignment (DPO/RPO) + +Code: `nemo/collections/tts/models/t5tts_preference_optimization.py` + +Preference Alignment (DPO/RPO) involves the following steps: +1) Create a list of text-context pairs for which we will generate preference data. +2) For each text-context pair, generate multiple audios from a base T5-TTS checkpoint and calculate metrics (CER/SSIM) for each generation. +3) Create chosen-rejected pairs from the generated audio. +4) Finetune the base T5-TTS checkpoint on the chosen-rejected pairs. + +#### 1. Create text-context pairs +We pair a list of challenging texts with context audios from the Riva and LibriTTS datasets. We add a similar number of regular texts from LibriTTS and Riva (paired with random context audios). We also include examples with text contexts. There are other options for generating text-context pairs. + +``` +python scripts/t5tts/dpo/create_text_contextpairs.py \ + --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ + --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ + --nsamples_perpair 6 ; +``` +Each pair is repeated `nsamples_perpair` times, which specifies how many samples we want to generate for each pair. The output manifest serves as the input for the next step. + +We can also explore other options for these text-context pairs depending on the task. + +#### 2. Generate audios for each text-context pair + +Next, we can generate audios from a base T5-TTS checkpoint using the following command.
We pass the `audio_dir` as "/" since our text-context pairs contain absolute paths. Model config arguments should be modified accordingly to match the base checkpoint architecture. We can run the below command on the cluster to generate audios across multiple nodes. This command saves the generated audios along with the metrics for each generation in the `exp_dir`. Each generated audio file is accompanied by a `.json` file that has the CER/SSIM metrics. + +Sample sub file on EOS: `/lustre/fsw/llmservice_nemo_speechlm/users/shehzeenh/launchscripts/newdatagendpo_decoder.sub` + +``` +python examples/tts/t5tts.py \ +--config-name=t5tts_inference \ +mode=test \ +batch_size=64 \ ++init_from_ptl_ckpt="/mountdir/checkpoints/continuouscheckpoints_ks1_ks3/decodercontext_small_282.ckpt" \ +exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282" \ ++test_ds_meta.textcontextpairs.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json" \ ++test_ds_meta.textcontextpairs.audio_dir="/" \ ++test_ds_meta.textcontextpairs.feature_dir="/" \ +model.model_type="decoder_context_tts" \ +model.t5_encoder.kernel_size=3 \ +model.t5_decoder.kernel_size=1 \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.alignment_loss_scale=0.002 \ +model.prior_scaling_factor=null \ +model.load_cached_codes_if_available=false \ +trainer.num_nodes=${SLURM_JOB_NUM_NODES} +``` +#### 3. Create chosen-rejected pairs from the generations + +Next, we go through the generated audio directory and create chosen-rejected pairs. + +``` +python scripts/t5tts/dpo/create_preference_pairs.py \ +--input_manifest /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json \ +--generated_audio_dir /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/audios \ +--group_size 6 \ +--cer_threshold 0.01 \ +--val_size 256 ; +``` + +`cer_threshold=0.01` means we filter out pairs in which the chosen CER is > 0.01. + +This command should save train and val manifests for DPO finetuning in the base directory of the generated_audio_dir, that is, `/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/` + +#### 4.
DPO Finetuning Command + +Finally, we perform DPO finetuning using the following command: + +``` +python examples/tts/t5tts.py \ +batch_size=4 \ ++init_from_ptl_ckpt="/mountdir/checkpoints/decoder_21_epoch_2.ckpt" \ ++mode="dpo_train" \ +max_epochs=10 \ +exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/TrainingsICML/decodercontext_small_282" \ +exp_manager.checkpoint_callback_params.always_save_nemo=false \ +model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ +model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ ++train_ds_meta.dpopreftrain.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_train_manifest.json" \ ++train_ds_meta.dpopreftrain.audio_dir="/" \ ++train_ds_meta.dpopreftrain.feature_dir="/" \ ++val_ds_meta.dpoprefval.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_val_manifest.json" \ ++val_ds_meta.dpoprefval.audio_dir="/" \ ++val_ds_meta.dpoprefval.feature_dir="/" \ ++model.dpo_beta=0.01 \ ++model.dpo_sft_loss_weight=0.0 \ +model.model_type="decoder_context_tts" \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.alignment_loss_scale=0.001 \ +model.prior_scaling_factor=null \ +trainer.val_check_interval=200 \ +trainer.log_every_n_steps=10 \ +model.optim.lr=2e-7 \ +~model.optim.sched \ +trainer.num_nodes=${SLURM_JOB_NUM_NODES} +``` + +Note the following overrides in the above command: + +``` ++mode="dpo_train" \ +model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ +model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ +``` + +Again, our manifests contain absolute paths, so we specify `audio_dir="/"`. + +### Online Preference Optimization (GRPO) + +For online preference optimization, the process is much simpler. + +1) Create a list of text-context pairs for which we will generate preference data (each text-context pair appears just once, not repeated). +We'll use the same process as above, just set `nsamples_perpair 1` in the command. +``` +python scripts/t5tts/dpo/create_text_contextpairs.py \ + --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ + --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ + --nsamples_perpair 1 ; +``` + +2.
Train using GRPO + +``` +python examples/tts/t5tts.py \ ++mode="onlinepo_train" \ ++init_from_ptl_ckpt="/Data/ICML2025_CKPTS/icml2025_base_checkpoints/decodercontext_small_sp_ks3CorrectWithPrior_onlyphoneme_epoch161.ckpt" \ +max_epochs=1000 \ +exp_manager.exp_dir="/Data/Experiments/NewT5TTSGRPO/Try3NoDropoutBeta0.01_CFG/" \ ++train_ds_meta.grpotrainnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_train_nomls.json" \ ++train_ds_meta.grpotrainnomls.audio_dir="/" \ ++train_ds_meta.grpotrainnomls.feature_dir="/" \ ++val_ds_meta.grpovalnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_val_unseenspeakers_tinysubset.json" \ ++val_ds_meta.grpovalnomls.audio_dir="/" \ ++val_ds_meta.grpovalnomls.feature_dir="/" \ ++model.num_generations_per_item=6 \ ++model.grpo_beta=0.01 \ +model.t5_decoder.p_dropout=0.0 \ +model.t5_encoder.p_dropout=0.0 \ +model.model_type="decoder_context_tts" \ +model.use_text_conditioning_encoder=true \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.codecmodel_path="/Data/Checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.alignment_loss_scale=0.0 \ +model.prior_scaling_factor=null \ +model.train_ds.dataloader_params.num_workers=0 \ +model.validation_ds.dataloader_params.num_workers=0 \ +exp_manager.checkpoint_callback_params.monitor="val_mean_reward" \ +exp_manager.checkpoint_callback_params.mode="max" \ ++trainer.use_distributed_sampler=False \ ++model.inference_cfg_prob=0.5 \ ++model.inference_cfg_scale=2.5 \ +batch_size=2 \ +model.optim.lr=1e-6 \ +trainer.devices=2 \ +trainer.log_every_n_steps=1 \ +trainer.val_check_interval=50 \ +~model.optim.sched ; +``` \ No newline at end of file diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 63aedb3eea46..83b1145673b0 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -452,6 +452,9 @@ def __getitem__(self, index): example['audio_codes'] = audio_codes example['audio_codes_len'] = audio_codes_len example['audio_filepath'] = audio_codes_path + if 'audio_filepath' in data.manifest_entry: + # If audio_filepath is available, then use the actual audio file path. 
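+                # (without this override, example['audio_filepath'] would keep pointing at the cached codes path assigned above)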
+ example['audio_filepath'] = data.manifest_entry['audio_filepath'] else: # Only load audio if codes are not available audio_array, _, audio_filepath_rel = load_audio( @@ -604,6 +607,7 @@ def __getitem__(self, index): example["align_prior"] = align_prior example['raw_text'] = data.text + example['language'] = data.manifest_entry.get('language', 'en') if "reward" in data.manifest_entry: example["reward"] = data.manifest_entry["reward"] @@ -631,10 +635,12 @@ def collate_fn(self, batch: List[dict]): context_has_text_context_list = [] reward_list = [] raw_text_list = [] + language_list = [] for example in batch: dataset_name_list.append(example["dataset_name"]) audio_filepath_list.append(example["audio_filepath"]) raw_text_list.append(example["raw_text"]) + language_list.append(example["language"]) token_list.append(example["tokens"]) token_len_list.append(example["text_len"]) @@ -677,6 +683,7 @@ def collate_fn(self, batch: List[dict]): batch_dict = { "dataset_names": dataset_name_list, "raw_texts": raw_text_list, + "languages": language_list, "audio_filepaths": audio_filepath_list, "text": batch_tokens, "text_lens": batch_token_len, diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index 2dba794253a6..c62800149460 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -22,6 +22,7 @@ from nemo.collections.tts.models.radtts import RadTTSModel from nemo.collections.tts.models.spectrogram_enhancer import SpectrogramEnhancerModel from nemo.collections.tts.models.ssl_tts import SSLDisentangler +from nemo.collections.tts.models.t5tts_preference_optimization import T5TTS_Model_PrefDataGen, T5TTS_Model_OfflinePO, T5TTS_Model_OnlinePO from nemo.collections.tts.models.tacotron2 import Tacotron2Model from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel from nemo.collections.tts.models.univnet import UnivNetModel @@ -42,6 +43,9 @@ "MagpieTTS_Model", "MagpieTTS_ModelInference", "MagpieTTS_ModelDPO", + "T5TTS_Model_PrefDataGen", + "T5TTS_Model_OfflinePO", + "T5TTS_Model_OnlinePO", "Tacotron2Model", "TwoStagesModel", "UnivNetModel", diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index ebcf98d984e8..719c786fd475 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -46,6 +46,12 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging +HAVE_WANDB = True +try: + import wandb +except ModuleNotFoundError: + HAVE_WANDB = False + def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mode='train'): # Being used in both model and worker_init_fn, so it is defined here @@ -1732,326 +1738,3 @@ def test_step(self, batch, batch_idx): ) as f: json.dump(item_metrics, f) - -class MagpieTTS_ModelDPO(MagpieTTS_Model): - """Extends MagpieTTS_Model to support Direct Preference Optimization (DPO) training. - This class is used for training the model with preference-based losses, including DPO, RPO, and IPO losses. - It maintains a frozen reference model to compare log probabilities between policy and reference outputs. - - """ - - def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): - """Initialize the MagpieTTS_ModelDPO class. - - Args: - cfg (DictConfig): Configuration object containing model hyperparameters. - trainer (Trainer, optional): Trainer instance for model training. 
- """ - super().__init__(cfg, trainer) - # Create a copy of the configuration for the reference model - ref_model_cfg = copy.deepcopy(cfg) - with open_dict(ref_model_cfg): - ref_model_cfg.train_ds = None - ref_model_cfg.validation_ds = None - - # Initialize the frozen reference model - self._reference_model = MagpieTTS_Model(cfg=ref_model_cfg) - print("Loading reference model from checkpoint") - self._reference_model.load_state_dict( - torch.load(cfg.reference_model_ckpt_path, map_location="cpu")['state_dict'] - ) - self.freeze_model(self._reference_model) - self._reference_model.eval() - self._reference_model._no_state_dict = True - print("Reference model loaded and frozen") - - def state_dict(self, destination=None, prefix='', keep_vars=False): - """Return the state dictionary excluding non-trainable components. - - Excludes state keys related to `_speaker_verification_model`, `_codec_model`, and `_reference_model`. - - Args: - destination (dict, optional): The destination dictionary for the state_dict. - prefix (str, optional): Prefix to prepend to keys. - keep_vars (bool, optional): If True, tensors in the returned dictionary will not be detached. - - Returns: - dict: Filtered state dictionary. - """ - state_dict = super().state_dict(destination, prefix, keep_vars) - keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model'] - for key in list(state_dict.keys()): - if any([substring in key for substring in keys_substrings_to_exclude]): - del state_dict[key] - return state_dict - - def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. - Shape: (batch_size, sequence_length) - average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return - the sum of the log probabilities of the (non-masked) tokens. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under - the given logits. - """ - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def preference_loss( - self, - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - chosen_gt_rewards=None, - rejected_gt_rewards=None, - beta=0.2, - gt_reward_scale=1.0, - label_smoothing=0, - loss_type="dpo", - reference_free=False, - ): - """Compute the DPO loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: Log probabilities of the policy model for the chosen responses. - Shape: (batch_size,) - policy_rejected_logps: Log probabilities of the policy model for the rejected responses. - Shape: (batch_size,) - reference_chosen_logps: Log probabilities of the reference model for the chosen responses. - Shape: (batch_size,) - reference_rejected_logps: Log probabilities of the reference model for the rejected responses. - Shape: (batch_size,) - beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore - the reference model as beta -> 0. 
- label_smoothing: conservativeness for DPO loss, which assumes that preferences are noisy (flipped with - probability label_smoothing) - ipo: If True, use the IPO loss instead of the DPO loss. - reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model - that assigns equal probability to all responses. - - Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). - The losses tensor contains the DPO loss for each example in the batch. - The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected - responses, respectively. - """ - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - if reference_free: - ref_logratios = 0 - - logits = pi_logratios - ref_logratios # also known as h_{\pi_\theta}^{y_w,y_l} - # logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps) - # logits = (policy_chosen_logps - reference_chosen_logps) - (policy_rejected_logps - reference_rejected_logps) - # logits is the same as rewards_delta in NeMo aligner - # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 - - if loss_type == "ipo": - losses = (logits - 1 / (2 * beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf - elif loss_type == "rpo": - # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 - logbeta_hat_chosen = torch.nn.functional.logsigmoid(beta * logits) - logbeta_hat_rejected = torch.nn.functional.logsigmoid(-beta * logits) - gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) - logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) - logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) - losses = torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) + torch.exp( - logalpha_hat_rejected - ) * (logalpha_hat_rejected - logbeta_hat_rejected) - elif loss_type == "rpo_sq": - gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) - losses = (beta * logits - gt_rewards_delta) ** 2 - elif loss_type == "dpo": - # Eq. 3 https://ericmitchell.ai/cdpo.pdf; - # label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) - F = torch.nn.functional - losses = ( - -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing - ) - else: - raise NotImplementedError("loss type {} is not implemented".format(loss_type)) - - chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() - rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() - - return losses, chosen_rewards, rejected_rewards - - def process_batch_dpo(self, batch_chosen_rejected): - """Process a batch for Direct Preference Optimization (DPO) training. - - This method computes the preference loss by comparing the model's policy outputs with a frozen reference model. - It processes chosen and rejected samples, extracts log probabilities for each codebook, and calculates the - preference loss based on the difference in likelihoods between chosen and rejected responses. - - Args: - batch_chosen_rejected (dict): A dictionary containing two keys: - - 'chosen': The batch of chosen responses. 
- - 'rejected': The batch of rejected responses. - - Returns: - dict: A dictionary containing: - - 'loss': The total computed loss. - - 'pref_loss': The preference loss. - - 'sft_loss': The supervised fine-tuning loss. - - 'alignment_loss': The alignment loss, if applicable. - """ - batch_chosen = batch_chosen_rejected['chosen'] - batch_rejected = batch_chosen_rejected['rejected'] - - model_output_chosen = self.process_batch(batch_chosen) - model_output_rejected = self.process_batch(batch_rejected) - with torch.no_grad(): - reference_model_output_chosen = self._reference_model.process_batch(batch_chosen) - reference_model_output_rejected = self._reference_model.process_batch(batch_rejected) - - chosen_policy_logprobs = None - rejected_policy_logprobs = None - chosen_ref_logprobs = None - rejected_ref_logprobs = None - for codebook_idx in range(self.cfg.num_audio_codebooks): - si = codebook_idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook - codebook_logits_chosen = model_output_chosen['logits'][:, :, si:ei] - codebook_logits_rejected = model_output_rejected['logits'][:, :, si:ei] - - ref_codebook_logits_chosen = reference_model_output_chosen['logits'][:, :, si:ei] - ref_codebook_logits_rejected = reference_model_output_rejected['logits'][:, :, si:ei] - - codebook_labels_chosen = model_output_chosen['audio_codes_target'][:, codebook_idx] - codebook_labels_rejected = model_output_rejected['audio_codes_target'][:, codebook_idx] - - codebook_log_probs_chosen = self._get_batch_logps( - codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask'] - ) - codebook_log_probs_rejected = self._get_batch_logps( - codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask'] - ) - with torch.no_grad(): - ref_codebook_log_probs_chosen = self._get_batch_logps( - ref_codebook_logits_chosen, codebook_labels_chosen, reference_model_output_chosen['loss_mask'] - ) - ref_codebook_log_probs_rejected = self._get_batch_logps( - ref_codebook_logits_rejected, - codebook_labels_rejected, - reference_model_output_rejected['loss_mask'], - ) - - if chosen_policy_logprobs is None: - chosen_policy_logprobs = codebook_log_probs_chosen - rejected_policy_logprobs = codebook_log_probs_rejected - chosen_ref_logprobs = ref_codebook_log_probs_chosen - rejected_ref_logprobs = ref_codebook_log_probs_rejected - else: - chosen_policy_logprobs += codebook_log_probs_chosen - rejected_policy_logprobs += codebook_log_probs_rejected - chosen_ref_logprobs += ref_codebook_log_probs_chosen - rejected_ref_logprobs += ref_codebook_log_probs_rejected - - rewards_chosen = batch_chosen['rewards'] - rewards_rejected = batch_rejected['rewards'] - - assert torch.all(rewards_chosen == 1) - assert torch.all(rewards_rejected < 1) - - pref_loss, chosen_rewards, rejected_rewards = self.preference_loss( - chosen_policy_logprobs, - rejected_policy_logprobs, - chosen_ref_logprobs, - rejected_ref_logprobs, - chosen_gt_rewards=rewards_chosen, - rejected_gt_rewards=rewards_rejected, - beta=self.cfg.get('dpo_beta', 0.01), - loss_type=self.cfg.get('dpo_loss_type', 'dpo'), - ) - - pref_loss = pref_loss.mean() - sft_loss = -chosen_policy_logprobs.mean() - - pref_loss_weight = self.cfg.get('dpo_pref_loss_weight', 1.0) - sft_loss_weight = self.cfg.get('dpo_sft_loss_weight', 0.0) - loss = pref_loss_weight * pref_loss + sft_loss * sft_loss_weight - - alignment_loss = model_output_chosen['alignment_loss'] - if alignment_loss is not None: - loss += alignment_loss - - return { - 
'loss': loss, - 'pref_loss': pref_loss, - 'sft_loss': sft_loss, - 'alignment_loss': alignment_loss, - } - - def training_step(self, batch, batch_idx): - """Perform a training step using DPO loss. - - Args: - batch (dict): Batch data containing chosen and rejected samples. - batch_idx (int): Index of the batch. - - Returns: - Tensor: Training loss. - """ - dpo_outputs = self.process_batch_dpo(batch) - self.log('train_loss', dpo_outputs['loss'], prog_bar=True, sync_dist=True) - self.log('train_pref_loss', dpo_outputs['pref_loss'], prog_bar=True, sync_dist=True) - self.log('train_sft_loss', dpo_outputs['sft_loss'], prog_bar=True, sync_dist=True) - return dpo_outputs['loss'] - - def validation_step(self, batch, batch_idx): - """Perform a validation step using DPO loss. - - Args: - batch (dict): Validation batch data. - batch_idx (int): Batch index. - """ - dpo_outputs = self.process_batch_dpo(batch) - - val_loss = dpo_outputs['loss'] - val_pref_loss = dpo_outputs['pref_loss'] - val_sft_loss = dpo_outputs['sft_loss'] - val_alignment_loss = dpo_outputs['alignment_loss'] - - self.validation_step_outputs.append( - { - 'val_loss': val_loss, - 'val_pref_loss': val_pref_loss, - 'val_sft_loss': val_sft_loss, - 'val_alignment_loss': val_alignment_loss, - } - ) - - def on_validation_epoch_end(self): - """Aggregate validation losses at the end of the validation epoch.""" - - def collect(key): - values = [] - for x in self.validation_step_outputs: - if x[key] is not None: - values.append(x[key]) - else: - values.append(torch.tensor(0.0, device=self.device)) - stacked_values = torch.stack(values) - return stacked_values.mean() - - val_loss = collect("val_loss") - val_pref_loss = collect("val_pref_loss") - val_sft_loss = collect("val_sft_loss") - val_alignment_loss = collect("val_alignment_loss") - self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) - self.log("val_pref_loss", val_pref_loss, prog_bar=True, sync_dist=True) - self.log("val_sft_loss", val_sft_loss, prog_bar=True, sync_dist=True) - if val_alignment_loss is not None: - self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) - self.validation_step_outputs.clear() diff --git a/nemo/collections/tts/models/t5tts_preference_optimization.py b/nemo/collections/tts/models/t5tts_preference_optimization.py new file mode 100644 index 000000000000..bd19fe3dbe88 --- /dev/null +++ b/nemo/collections/tts/models/t5tts_preference_optimization.py @@ -0,0 +1,815 @@ +import numpy as np +import torch +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from torch import nn +import os +import json +from nemo.utils import logging +import nemo.collections.asr as nemo_asr +import soundfile as sf +import librosa +import copy +from omegaconf import open_dict +import string +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors +import random +try: + import torchaudio + from torchaudio.pipelines import SQUIM_OBJECTIVE + HAVE_TORCHAUDIO = True +except ImportError: + HAVE_TORCHAUDIO = False + +from nemo.collections.tts.models import T5TTS_Model + + +class T5TTS_Model_PrefDataGen(T5TTS_Model): + """Small override to save inference metrics, used for datagen in Offline PO""" + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + super().__init__(cfg, trainer) + if cfg.get('pref_set_language', "en") == "en": + self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") + 
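+            # The ASR scorer is evaluation-only; keep it frozen and in eval mode so it never receives gradients.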
self.eval_asr_model.freeze() + self.eval_asr_model.eval() + + self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + self.eval_speaker_verification_model.freeze() + self.eval_speaker_verification_model.eval() + + if cfg.get('load_whisper_model', False): + from transformers import WhisperProcessor, WhisperForConditionalGeneration + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + self.whisper_model.eval() + + def test_step(self, batch, batch_idx): + with torch.no_grad(): + test_dl_batch_size = self._test_dl.batch_size + temperature = self.cfg.get('inference_temperature', 0.7) + topk = self.cfg.get('inference_topk', 80) + use_cfg = self.cfg.get('inference_use_cfg', False) + cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( + batch, + max_decoder_steps=self.cfg.get('max_decoder_steps', 500), + temperature=temperature, + topk=topk, + use_cfg=use_cfg, + cfg_scale=cfg_scale + ) + predicted_audio_paths = [] + audio_durations = [] + batch_invalid = False + for idx in range(predicted_audio.size(0)): + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] + item_idx = batch_idx * test_dl_batch_size + idx + # Save the predicted audio + log_dir = self.logger.log_dir + audio_dir = os.path.join(log_dir, 'audios') + if not os.path.exists(audio_dir): + os.makedirs(audio_dir) + audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') + audio_durations.append(len(predicted_audio_np) / self.cfg.sample_rate) + sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) + + predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) + predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] + torch.save(predicted_codes_torch, os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt')) + predicted_audio_paths.append(audio_path) + + if not batch_invalid: + with torch.no_grad(): + try: + if self.cfg.get("pref_set_language", "en") == "en": + pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths)) + pred_transcripts = [ process_text_for_cer(transcript) for transcript in pred_transcripts ] + else: + pred_transcripts = [] + for audio_path in predicted_audio_paths: + transcript = transcribe_with_whisper(audio_path, self.cfg.pref_set_language, self.whisper_processor, self.whisper_model, self.device) + pred_transcripts.append(transcript) + + pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] + except Exception as e: + assert (predicted_audio_lens[idx] < 1000).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}" + logging.warning(f"Exception during ASR transcription: {e}") + logging.warning(f"Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0") + batch_invalid = True + continue # don't break since we want to continue building audio durations list + pred_speaker_embeddings = get_speaker_embeddings_from_filepaths(predicted_audio_paths, self.eval_speaker_verification_model, self.device) + gt_speaker_embeddings = 
get_speaker_embeddings_from_filepaths(batch['audio_filepaths'], self.eval_speaker_verification_model, self.device) + + for idx in range(predicted_audio.size(0)): + if not batch_invalid: + audio_path = predicted_audio_paths[idx] + item_idx = batch_idx * test_dl_batch_size + idx + pred_transcript = pred_transcripts[idx] + gt_transcript = process_text_for_cer(batch['raw_texts'][idx]) + + cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True) + wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False) + + spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy() + spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy() + + spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) + ) + else: + # Create an entry indicating invalid metrics + cer_gt = 1.0 + wer_gt = 1.0 + spk_similarity = 0.0 + pred_transcript = "" # do not change this string; subsequent processing relies on it + gt_transcript = process_text_for_cer(batch['raw_texts'][idx]) + + item_metrics = { + 'cer_gt': float(cer_gt), + 'wer_gt': float(wer_gt), + 'duration' : audio_durations[idx], + 'spk_similarity': float(spk_similarity), + 'pred_transcript': pred_transcript, + 'gt_transcript': gt_transcript, + } + + with open(os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w') as f: + json.dump(item_metrics, f) + +class T5TTS_Model_OfflinePO(T5TTS_Model): + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + super().__init__(cfg, trainer) + # Copy cfg + ref_model_cfg = copy.deepcopy(cfg) + with open_dict(ref_model_cfg): + ref_model_cfg.train_ds = None + ref_model_cfg.validation_ds = None + self._reference_model = T5TTS_Model(cfg=ref_model_cfg) + print("Loading reference model from checkpoint") + self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) + self.freeze_model(self._reference_model) + self._reference_model.eval() + self._reference_model._no_state_dict = True + print("Reference model loaded and frozen") + + def state_dict(self, destination=None, prefix='', keep_vars=False): + state_dict = super().state_dict(destination, prefix, keep_vars) + keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model'] + for key in list(state_dict.keys()): + if any([substring in key for substring in keys_substrings_to_exclude]): + del state_dict[key] + return state_dict + + + def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
+ """ + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + # https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py + def preference_loss(self, policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + chosen_gt_rewards=None, + rejected_gt_rewards=None, + beta=0.2, + gt_reward_scale=1.0, + label_smoothing=0, + loss_type="dpo", + reference_free=False): + """Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. + label_smoothing: conservativeness for DPO loss, which assumes that preferences are noisy (flipped with probability label_smoothing) + ipo: If True, use the IPO loss instead of the DPO loss. + reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + if reference_free: + ref_logratios = 0 + + logits = pi_logratios - ref_logratios # also known as h_{\pi_\theta}^{y_w,y_l} + # logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps) + # logits = (policy_chosen_logps - reference_chosen_logps) - (policy_rejected_logps - reference_rejected_logps) + # logits is the same as rewards_delta in NeMo aligner: https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 + + if loss_type == "ipo": + losses = (logits - 1/(2 * beta)) ** 2 # Eq. 
17 of https://arxiv.org/pdf/2310.12036v2.pdf + elif loss_type == "rpo": + # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 + logbeta_hat_chosen = torch.nn.functional.logsigmoid(beta * logits) + logbeta_hat_rejected = torch.nn.functional.logsigmoid(-beta * logits) + gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) + logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) + logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) + losses = ( + torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) + + torch.exp(logalpha_hat_rejected) * (logalpha_hat_rejected - logbeta_hat_rejected) + ) + elif loss_type == "rpo_sq": + gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) + losses = (beta * logits - gt_rewards_delta) ** 2 + elif loss_type == "dpo": + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + F = torch.nn.functional + losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing + else: + raise NotImplementedError("loss type {} is not implemented".format(loss_type)) + + chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + + def process_batch_dpo(self, batch_chosen_rejected): + batch_chosen = batch_chosen_rejected['chosen'] + batch_rejected = batch_chosen_rejected['rejected'] + + model_output_chosen = self.process_batch(batch_chosen) + model_output_rejected = self.process_batch(batch_rejected) + with torch.no_grad(): + reference_model_output_chosen = self._reference_model.process_batch(batch_chosen) + reference_model_output_rejected = self._reference_model.process_batch(batch_rejected) + + chosen_policy_logprobs = None + rejected_policy_logprobs = None + chosen_ref_logprobs = None + rejected_ref_logprobs = None + for codebook_idx in range(self.cfg.num_audio_codebooks): + si = codebook_idx * self.cfg.num_audio_tokens_per_codebook + ei = si + self.cfg.num_audio_tokens_per_codebook + codebook_logits_chosen = model_output_chosen['logits'][:, :, si:ei] + codebook_logits_rejected = model_output_rejected['logits'][:, :, si:ei] + + ref_codebook_logits_chosen = reference_model_output_chosen['logits'][:, :, si:ei] + ref_codebook_logits_rejected = reference_model_output_rejected['logits'][:, :, si:ei] + + codebook_labels_chosen = model_output_chosen['audio_codes_target'][:,codebook_idx] + codebook_labels_rejected = model_output_rejected['audio_codes_target'][:,codebook_idx] + + codebook_log_probs_chosen = self._get_batch_logps(codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask']) + codebook_log_probs_rejected = self._get_batch_logps(codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask']) + with torch.no_grad(): + ref_codebook_log_probs_chosen = self._get_batch_logps(ref_codebook_logits_chosen, codebook_labels_chosen, reference_model_output_chosen['loss_mask']) + ref_codebook_log_probs_rejected = self._get_batch_logps(ref_codebook_logits_rejected, codebook_labels_rejected, reference_model_output_rejected['loss_mask']) + + if chosen_policy_logprobs is None: + chosen_policy_logprobs = codebook_log_probs_chosen + rejected_policy_logprobs = 
codebook_log_probs_rejected + chosen_ref_logprobs = ref_codebook_log_probs_chosen + rejected_ref_logprobs = ref_codebook_log_probs_rejected + else: + chosen_policy_logprobs += codebook_log_probs_chosen + rejected_policy_logprobs += codebook_log_probs_rejected + chosen_ref_logprobs += ref_codebook_log_probs_chosen + rejected_ref_logprobs += ref_codebook_log_probs_rejected + + rewards_chosen = batch_chosen['rewards'] + rewards_rejected = batch_rejected['rewards'] + + assert torch.all(rewards_chosen == 1) + assert torch.all(rewards_rejected < 1) + + pref_loss, chosen_rewards, rejected_rewards = self.preference_loss( + chosen_policy_logprobs, + rejected_policy_logprobs, + chosen_ref_logprobs, + rejected_ref_logprobs, + chosen_gt_rewards=rewards_chosen, + rejected_gt_rewards=rewards_rejected, + beta=self.cfg.get('dpo_beta', 0.01), + loss_type=self.cfg.get('dpo_loss_type', 'dpo'), + ) + + pref_loss = pref_loss.mean() + sft_loss = -chosen_policy_logprobs.mean() + + pref_loss_weight = self.cfg.get('dpo_pref_loss_weight', 1.0) + sft_loss_weight = self.cfg.get('dpo_sft_loss_weight', 0.0) + loss = pref_loss_weight * pref_loss + sft_loss * sft_loss_weight + + alignment_loss = model_output_chosen['alignment_loss'] + if alignment_loss is not None: + loss += alignment_loss + + return { + 'loss': loss, + 'pref_loss': pref_loss, + 'sft_loss': sft_loss, + 'alignment_loss': alignment_loss, + } + + def training_step(self, batch, batch_idx): + dpo_outputs = self.process_batch_dpo(batch) + self.log('train_loss', dpo_outputs['loss'], prog_bar=True, sync_dist=True) + self.log('train_pref_loss', dpo_outputs['pref_loss'], prog_bar=True, sync_dist=True) + self.log('train_sft_loss', dpo_outputs['sft_loss'], prog_bar=True, sync_dist=True) + return dpo_outputs['loss'] + + def validation_step(self, batch, batch_idx): + dpo_outputs = self.process_batch_dpo(batch) + + val_loss = dpo_outputs['loss'] + val_pref_loss = dpo_outputs['pref_loss'] + val_sft_loss = dpo_outputs['sft_loss'] + val_alignment_loss = dpo_outputs['alignment_loss'] + + self.validation_step_outputs.append({ + 'val_loss': val_loss, + 'val_pref_loss': val_pref_loss, + 'val_sft_loss': val_sft_loss, + 'val_alignment_loss': val_alignment_loss, + }) + + def on_validation_epoch_end(self): + def collect(key): + values = [] + for x in self.validation_step_outputs: + if x[key] is not None: + values.append(x[key]) + else: + values.append(torch.tensor(0.0, device=self.device)) + stacked_values = torch.stack(values) + return stacked_values.mean() + + val_loss = collect("val_loss") + val_pref_loss = collect("val_pref_loss") + val_sft_loss = collect("val_sft_loss") + val_alignment_loss = collect("val_alignment_loss") + self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) + self.log("val_pref_loss", val_pref_loss, prog_bar=True, sync_dist=True) + self.log("val_sft_loss", val_sft_loss, prog_bar=True, sync_dist=True) + if val_alignment_loss is not None: + self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) + self.validation_step_outputs.clear() + +class T5TTS_Model_OnlinePO(T5TTS_Model): + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + super().__init__(cfg, trainer) + # Copy cfg + ref_model_cfg = copy.deepcopy(cfg) + with open_dict(ref_model_cfg): + ref_model_cfg.train_ds = None + ref_model_cfg.validation_ds = None + + self.reference_free = self.cfg.get('reference_free', False) # True means we dont use the reference model + if not self.reference_free: + self._reference_model = T5TTS_Model(cfg=ref_model_cfg) + 
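+            # The frozen reference policy is used only to compute the KL penalty (scaled by grpo_beta) against the current policy.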
print("Loading reference model from checkpoint") + self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) + self.freeze_model(self._reference_model) + self._reference_model.eval() + self._reference_model._no_state_dict = True + print("Reference model loaded and frozen") + + if cfg.get('reward_asr_model', "nemo") == "nemo": + self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") + self.eval_asr_model.freeze() + self.eval_asr_model.eval() + elif cfg.get('reward_asr_model', "nemo") == "whisper": + from transformers import WhisperProcessor, WhisperForConditionalGeneration + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + self.whisper_model.eval() + else: + raise ValueError(f"Unknown reward_asr_model: {cfg.reward_asr_model}") + + self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + self.eval_speaker_verification_model.freeze() + self.eval_speaker_verification_model.eval() + + if cfg.get('load_whisper_model', False): + from transformers import WhisperProcessor, WhisperForConditionalGeneration + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + self.whisper_model.eval() + + use_pesq = self.cfg.get('use_pesq', False) + if use_pesq: + # import ipdb; ipdb.set_trace() + assert HAVE_TORCHAUDIO, "torchaudio is required for PESQ reward" + self.squim_objective_model = SQUIM_OBJECTIVE.get_model() + + def state_dict(self, destination=None, prefix='', keep_vars=False): + state_dict = super().state_dict(destination, prefix, keep_vars) + keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model', + 'eval_asr_model', 'eval_speaker_verification_model', 'whisper_model'] + for key in list(state_dict.keys()): + if any([substring in key for substring in keys_substrings_to_exclude]): + del state_dict[key] + return state_dict + + + def _get_per_token_logps(self, logits, labels, loss_mask): + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. 
+ """ + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = per_token_logps * loss_mask + return per_token_logps + + def repeat_items_in_batch(self, batch, num_repeats): + repeated_batch = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + repeated_value = value.repeat_interleave(num_repeats, dim=0) + elif isinstance(value, list): + repeated_value = [] + for item in value: + repeated_value.extend([item] * num_repeats) + else: + repeated_value = value + repeated_batch[key] = repeated_value + return repeated_batch + + def generate_and_reward(self, batch, num_generations_per_item, mode='train'): + batch_repeated = self.repeat_items_in_batch(batch, num_generations_per_item) + temperature = self.cfg.get('inference_temperature', 0.7) + topk = self.cfg.get('inference_topk', 80) + use_cfg = False + cfg_scale = 1.0 + use_pesq = self.cfg.get('use_pesq', False) + inference_cfg_prob = self.cfg.get('inference_cfg_prob', 0.0) + if (inference_cfg_prob == 1.0) or (inference_cfg_prob > 0.0 and mode == 'train'): + # Randomly set use_cfg based on the given probability + use_cfg = random.random() < self.cfg.inference_cfg_prob + cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) + print("use_cfg", use_cfg) + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( + batch_repeated, + max_decoder_steps=self.cfg.get('max_decoder_steps', 430), + temperature=temperature, + topk=topk, + use_cfg=use_cfg, + cfg_scale=cfg_scale + ) + predicted_audio_paths = [] + audio_durations = [] + for idx in range(predicted_audio.size(0)): + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] + if predicted_audio_np.shape[0] < 1000: + # Corner case to handle short audio files + predicted_audio_np = np.pad(predicted_audio_np, (0, 1000 - predicted_audio_np.shape[0])) + item_idx = idx + # Save the predicted audio + log_dir = self.logger.log_dir + audio_dir = os.path.join(log_dir, 'audios') + os.makedirs(audio_dir, exist_ok=True) + audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') + audio_durations.append(len(predicted_audio_np) / self.cfg.sample_rate) + sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) + + predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) + predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] # C, T + torch.save(predicted_codes_torch, os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt')) + predicted_audio_paths.append(audio_path) + + with torch.no_grad(): + if self.cfg.get("reward_asr_model", "nemo") == "nemo": + pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths)) + pred_transcripts = [ process_text_for_cer(transcript) for transcript in pred_transcripts ] + elif self.cfg.get("reward_asr_model", "nemo") == "whisper": + pred_transcripts = [] + for item_idx, audio_path in enumerate(predicted_audio_paths): + language = batch_repeated['languages'][item_idx] + transcript = transcribe_with_whisper(audio_path, language, self.whisper_processor, self.whisper_model, self.device) + pred_transcripts.append(transcript) + pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] + + pred_speaker_embeddings = get_speaker_embeddings_from_filepaths(predicted_audio_paths, 
self.eval_speaker_verification_model, self.device) + gt_speaker_embeddings = get_speaker_embeddings_from_filepaths(batch_repeated['audio_filepaths'], self.eval_speaker_verification_model, self.device) + + + batch_metrics = [] + cer_reward_weight = self.cfg.get('cer_reward_weight', 0.5) + ssim_reward_weight = self.cfg.get('ssim_reward_weight', 0.5) + pesq_reward_weight = self.cfg.get('pesq_reward_weight', 0.0) + for idx in range(predicted_audio.size(0)): + audio_path = predicted_audio_paths[idx] + item_idx = idx + pred_transcript = pred_transcripts[idx] + gt_transcript = process_text_for_cer(batch_repeated['raw_texts'][idx]) + cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True) + wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False) + spk_embedding_pred = pred_speaker_embeddings[idx].cpu().float().numpy() + spk_embedding_gt = gt_speaker_embeddings[idx].cpu().float().numpy() + spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) + ) + if use_pesq: + sample_audio, sr = torchaudio.load(audio_path) + sample_audio = sample_audio.to(self.device) + if sr != 16000: + sample_audio = torchaudio.functional.resample(sample_audio, sr, 16000) + _, pesq_hyp, _ = self.squim_objective_model(sample_audio) + pesq_hyp = pesq_hyp.item() + + item_metrics = { + 'cer_gt': float(cer_gt), + 'wer_gt': float(wer_gt), + 'duration' : audio_durations[idx], + 'spk_similarity': float(spk_similarity), + 'pred_transcript': pred_transcript, + 'gt_transcript': gt_transcript, + 'codes_len': predicted_codes_lens[idx].item(), + 'pesq' : pesq_hyp if use_pesq else 0.0, + } + with open(os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w') as f: + json.dump(item_metrics, f) + + batch_metrics.append(item_metrics) + + num_groups = len(batch['audio_filepaths']) + + best_ssim_achievable = self.cfg.get("best_ssim_achievable", 0.9) # Examples with this speaker similarity or higher will have SSIM reward of 1 + mean_cer_dataset = self.cfg.get("mean_cer_dataset", 0.1) # CER equal to this value will have reward of 0.5 + mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6) # SSIM equal to this value will have reward of 0.5 + for group_idx in range(num_groups): + group_start_idx = group_idx * num_generations_per_item + group_end_idx = group_start_idx + num_generations_per_item + group_rewards = [] + mean_reward = 0 + for idx in range(group_start_idx, group_end_idx): + # Lower CER and higher speaker similarity is better, means high reward + # Higher pesq is better, means high reward + # Reward for best CER and best speaker similarity should be 1 + item_cer = batch_metrics[idx]['cer_gt'] + item_ssim = batch_metrics[idx]['spk_similarity'] + item_cer = min( max(item_cer, 0.0), 1.0) + item_ssim = max( min(item_ssim, best_ssim_achievable), 0.0) + item_pesq = batch_metrics[idx]['pesq'] + if item_cer <= mean_cer_dataset: + cer_reward = 0.5 + 0.5 * (mean_cer_dataset - item_cer) / mean_cer_dataset # 0.5 to 1 + else: + cer_reward = 0.5 - 0.5 * (item_cer - mean_cer_dataset) / (1 - mean_cer_dataset) # 0 to 0.5 + if item_ssim >= mean_ssim_dataset: + spk_similarity_reward = 0.5 + 0.5 * (item_ssim - mean_ssim_dataset) / (best_ssim_achievable - mean_ssim_dataset) + else: + spk_similarity_reward = 0.5 - 0.5 * (mean_ssim_dataset - item_ssim) / (mean_ssim_dataset) + if use_pesq: + pesq_reward = item_pesq / 4.5 + else: + pesq_reward = 0.0 + + batch_metrics[idx]['reward'] = cer_reward * 
cer_reward_weight + spk_similarity_reward * ssim_reward_weight + pesq_reward * pesq_reward_weight + + if (batch_metrics[idx]['codes_len'] >= 425) or (batch_metrics[idx]['codes_len'] <= 3): # TODO: Remove hardcoded lengths + # This means it did not complete the sentence or generated an extremely short sentence + batch_metrics[idx]['reward'] = 0.0 + print("Item idx: ", idx, " CER: ", item_cer, " SSIM: ", item_ssim, " Reward: ", batch_metrics[idx]['reward'], " Codes len: ", batch_metrics[idx]['codes_len']) + batch_metrics[idx]['cer_reward'] = cer_reward + batch_metrics[idx]['spk_similarity_reward'] = spk_similarity_reward + batch_metrics[idx]['pesq_reward'] = pesq_reward + mean_reward += batch_metrics[idx]['reward'] + group_rewards.append(batch_metrics[idx]['reward']) + + mean_reward /= num_generations_per_item + std_reward = np.std(group_rewards) + for idx in range(group_start_idx, group_end_idx): + batch_metrics[idx]['advantage'] = (batch_metrics[idx]['reward'] - mean_reward) / (std_reward + 1e-6) + + + advantages = [x['advantage'] for x in batch_metrics] + advantages = torch.tensor(advantages, device=self.device) + print("Mean reward: ", mean_reward) + return { + 'mean_reward': torch.tensor(mean_reward, device=self.device), + 'batch_repeated': batch_repeated, + 'metrics': batch_metrics, + 'predicted_codes': predicted_codes, + 'predicted_codes_lens': predicted_codes_lens, + 'advantages': advantages, + } + + def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): + use_kv_cache_during_online_po = self.cfg.get("use_kv_cache_during_online_po", False) + if use_kv_cache_during_online_po: + self.use_kv_cache_for_inference = True + self.t5_decoder.reset_cache(use_cache=True) + + with torch.no_grad(): + generated_codes_and_metrics = self.generate_and_reward(batch, n_generations_per_item, mode) + + if use_kv_cache_during_online_po: + self.use_kv_cache_for_inference = False + self.t5_decoder.reset_cache(use_cache=False) + + batch_repeated = generated_codes_and_metrics['batch_repeated'] + predicted_codes = generated_codes_and_metrics['predicted_codes'] # B, 8, T + predicted_codes_lens = generated_codes_and_metrics['predicted_codes_lens'] # B + predicted_codes = predicted_codes[:,:,:predicted_codes_lens.max()] + + advantages = generated_codes_and_metrics['advantages'] # B + # Add extra tokens for BOS and EOS + bos_tensor = torch.full((predicted_codes.size(0), predicted_codes.size(1), 1), self.audio_bos_id, dtype=predicted_codes.dtype, device=predicted_codes.device) + padding_tensor = torch.full((predicted_codes.size(0), predicted_codes.size(1), 1), 0, dtype=predicted_codes.dtype, device=predicted_codes.device) + predicted_codes = torch.cat([bos_tensor, predicted_codes, padding_tensor], dim=2) + for idx in range(predicted_codes.size(0)): + predicted_codes[idx, :, predicted_codes_lens[idx]+1] = self.audio_eos_id # Accounts for BOS + batch_repeated['audio_codes'] = predicted_codes + batch_repeated['audio_codes_lens'] = predicted_codes_lens + 2 # Accounts for BOS and EOS + if 'audio' in batch_repeated: + del batch_repeated['audio'] + if 'audio_lens' in batch_repeated: + del batch_repeated['audio_lens'] + + policy_model_outputs = self.process_batch(batch_repeated) + + if not self.reference_free: + with torch.no_grad(): + reference_model_output = self._reference_model.process_batch(batch_repeated) + + total_loss = None + total_kl = None + for codebook_idx in range(self.cfg.num_audio_codebooks): + si = codebook_idx * self.cfg.num_audio_tokens_per_codebook + ei = si + 
self.cfg.num_audio_tokens_per_codebook + codebook_logits = policy_model_outputs['logits'][:, :, si:ei] # B, T, C + codebook_labels = batch_repeated['audio_codes'][:,codebook_idx,1:] + per_token_codebook_log_probs = self._get_per_token_logps(codebook_logits, codebook_labels, policy_model_outputs['loss_mask']) + per_token_loss = -(torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach()) * advantages.unsqueeze(1)) + + + if not self.reference_free: + with torch.no_grad(): + ref_codebook_logits = reference_model_output['logits'][:, :, si:ei] + per_token_ref_codebook_log_probs = self._get_per_token_logps(ref_codebook_logits, codebook_labels, reference_model_output['loss_mask']) + # https://github.com/huggingface/trl/blob/ffcb9f4aee725a2bd072d0387afe68a4b1c7967c/trl/trainer/grpo_trainer.py#L703 + per_token_codebook_kl = torch.exp(per_token_ref_codebook_log_probs - per_token_codebook_log_probs) - (per_token_ref_codebook_log_probs - per_token_codebook_log_probs) - 1 + per_token_loss = per_token_loss + self.cfg.grpo_beta * per_token_codebook_kl + codebook_kl_loss_mean = ((per_token_codebook_kl * policy_model_outputs['loss_mask']).sum(dim=1) / policy_model_outputs['loss_mask'].sum(dim=1)).mean() + else: + codebook_kl_loss_mean = torch.tensor(0.0, device=self.device) + + codebook_loss = ((per_token_loss * policy_model_outputs['loss_mask']).sum(dim=1) / policy_model_outputs['loss_mask'].sum(dim=1)).mean() + + if total_loss is None: + total_loss = codebook_loss + total_kl = codebook_kl_loss_mean + else: + total_loss += codebook_loss + total_kl += codebook_kl_loss_mean + + + total_loss /= self.cfg.num_audio_codebooks + print("Total kl", total_kl, n_generations_per_item) + return { + 'mean_reward': generated_codes_and_metrics['mean_reward'], + 'loss': total_loss, + 'kl_loss': total_kl, + 'batch_metrics': generated_codes_and_metrics['metrics'], + } + + def training_step(self, batch, batch_idx): + torch.cuda.empty_cache() + n_generations_per_item = self.cfg.get('n_generations_per_item', 6) + po_outputs = self.process_batch_online_po(batch, n_generations_per_item) + self.log('train_loss', po_outputs['loss'], prog_bar=True, sync_dist=True) + self.log('train_kl_loss', po_outputs['kl_loss'], prog_bar=True, sync_dist=True) + self.log('train_mean_reward', po_outputs['mean_reward'], prog_bar=True, sync_dist=True) + return po_outputs['loss'] + + def validation_step(self, batch, batch_idx): + po_outputs = self.process_batch_online_po(batch, 1, mode='val') + batch_metrics = po_outputs['batch_metrics'] + mean_reward = po_outputs['mean_reward'] + val_loss = po_outputs['loss'] + val_kl_loss = po_outputs['kl_loss'] + + self.validation_step_outputs.append({ + 'mean_reward': mean_reward, + 'val_loss': val_loss, + 'val_kl_loss': val_kl_loss, + 'batch_metrics': batch_metrics, + }) + + def on_validation_epoch_end(self): + def collect(key): + values = [] + for x in self.validation_step_outputs: + if x[key] is not None: + values.append(x[key]) + else: + values.append(torch.tensor(0.0, device=self.device)) + stacked_values = torch.stack(values) + return stacked_values.mean() + + val_loss = collect("val_loss") + val_kl_loss = collect("val_kl_loss") + mean_reward = collect("mean_reward") + + self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) + self.log("val_kl_loss", val_kl_loss, prog_bar=True, sync_dist=True) + self.log("val_mean_reward", mean_reward, prog_bar=True, sync_dist=True) + + mean_metrics = {} + for val_output in self.validation_step_outputs: + batch_metrics = 
val_output['batch_metrics'] + for item_metrics in batch_metrics: + for key, value in item_metrics.items(): + if "transcript" not in key: + if key not in mean_metrics: + mean_metrics[key] = [] + mean_metrics[key].append(value) + + for key, values in mean_metrics.items(): + mean_metrics[key] = np.mean(values) + self.log(f"val_{key}", mean_metrics[key], prog_bar=True, sync_dist=True) + + self.validation_step_outputs.clear() + + + +# Utility functions +def process_text_for_cer(input_text): + """ + Normalizes text for CER/WER calculation. + Taken from hallucination_eval.py + """ + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + no_dash_text = no_dash_text.replace("'", "") + no_dash_text = no_dash_text.replace(";", "") + no_dash_text = no_dash_text.replace(".", "") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + # @shehzeen: Added this to handle some common errors in ASR transcripts + single_space_text.replace("h t t p", "http") + single_space_text.replace("w w w", "www") + + return single_space_text + +def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, device): + audio_batch = [] + audio_lengths = [] + for filepath in filepaths: + audio, sr = sf.read(filepath) + if sr != 16000: + audio = librosa.core.resample(audio, orig_sr=sr, target_sr=16000) + audio_tensor = torch.tensor(audio, dtype=torch.float32, device=device) + audio_batch.append(audio_tensor) + audio_lengths.append(audio_tensor.size(0)) + + batch_audio_lens = torch.tensor(audio_lengths, device=device).long() + max_audio_len = int(batch_audio_lens.max().item()) + audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len]) + + _, speaker_embeddings = speaker_verification_model.forward( + input_signal=audio_batch, + input_signal_length=batch_audio_lens + ) + + return speaker_embeddings + +def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper_model, device): + print("Transcribing with whisper", language) + speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000) + forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language) if language else None + inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features + inputs = inputs.to(device) + with torch.no_grad(): + predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) + transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) + result = transcription[0] + return result \ No newline at end of file diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 78f0715251be..fae5319fcc8d 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -1408,7 +1408,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st if isinstance(cfg.init_from_ptl_ckpt, str): # Restore checkpoint ckpt_path = cfg.pop('init_from_ptl_ckpt') - ckpt = torch.load(ckpt_path, map_location=map_location) + ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=False) # Restore checkpoint into current model self.load_state_dict(ckpt['state_dict'], strict=False) diff --git a/scripts/magpietts/dpo/create_text_contextpairs.py 
b/scripts/magpietts/dpo/create_text_contextpairs.py index 74ee5ff6b92f..19f00835c043 100644 --- a/scripts/magpietts/dpo/create_text_contextpairs.py +++ b/scripts/magpietts/dpo/create_text_contextpairs.py @@ -37,12 +37,17 @@ def main(): Example usage: python scripts/t5tts/dpo/create_text_contextpairs.py \ - --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ - --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ - --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ - --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ - --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ - --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json + --challenging_texts /Data/DPOPairsInputDatav2/challenging_with_short.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputDatav2/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputDatav2/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputDatav2/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputDatav2/text_context_list_with_audio.txt \ + --output_manifest /Data/DPOPairsInputDatav2/grpo_train_with_short.json \ + --n_audio_contexts_per_challenging_text 2 \ + --n_text_contexts_per_challenging_text 2 \ + --n_audio_contexts_per_regular_text 1 \ + --n_text_contexts_per_regular_text 1 \ + --nsamples_perpair 1 ; """ parser = argparse.ArgumentParser(description='Create text-context pairs for DPO') parser.add_argument("--challenging_texts", type=str, help="Text file containing challenging texts") @@ -83,8 +88,6 @@ def main(): text_contexts = [text for text in text_contexts if text.strip() != ''] all_records = [] - dummy_audio_filepath = audio_contexts[0]['context_audio_filepath'] - dummy_target_audio_codes_path = audio_contexts[0].get('context_audio_codes_path', None) for challenging_text in challenging_texts: for _ in range(args.n_audio_contexts_per_challenging_text): audio_context = random.choice(audio_contexts) @@ -93,9 +96,7 @@ def main(): for _ in range(args.n_text_contexts_per_challenging_text): text_context = random.choice(text_contexts) - record = create_text_context_record( - challenging_text, text_context, dummy_audio_filepath, 'challenging', dummy_target_audio_codes_path - ) + record = create_text_context_record(challenging_text, text_context, 'challenging') all_records.append(record) for regular_text in regular_texts_for_audiocontext: @@ -107,9 +108,7 @@ def main(): for regular_text in regular_texts_for_textcontext: for _ in range(args.n_text_contexts_per_regular_text): text_context = random.choice(text_contexts) - record = create_text_context_record( - regular_text, text_context, dummy_audio_filepath, 'regular', dummy_target_audio_codes_path - ) + record = create_text_context_record(regular_text, text_context, 'regular') all_records.append(record) random.shuffle(all_records) @@ -150,30 +149,29 @@ def create_audio_context_record(text, audio_context, record_type): return record - -def create_text_context_record(text, text_context, dummy_audio_filepath, record_type, target_audio_codes_path=None): +def create_text_context_record(text, text_context, record_type): """ Creates a record for a text-context pair with text context. Args: text (str): The main text content. text_context (str): The associated text context. - dummy_audio_filepath (str): A placeholder audio file path. record_type (str): Type of record ('challenging' or 'regular'). 
- target_audio_codes_path (str, optional): Optional target audio codes path. Returns: dict: A dictionary representing the text context record. """ + if text_context.endswith("\n"): + text_context = text_context[:-1] record = { - 'text': text, - 'duration': 6.0, # Does not matter, avoids filtering out in DPO, - 'audio_filepath': dummy_audio_filepath, - 'context_text': text_context, - 'record_type': record_type, # challenging or regular + 'text' : text, + 'duration' : 6.0, # Does not matter, avoids filtering out in DPO, + 'audio_filepath': text_context.split(",")[1], + 'context_text' : text_context.split(",")[0], + 'record_type' : record_type # challenging or regular } - if target_audio_codes_path is not None: - record['target_audio_codes_path'] = target_audio_codes_path + if text_context.split(",")[-1].endswith(".pt"): + record['target_audio_codes_path'] = text_context.split(",")[-1] return record diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 3e993000bf78..b44c1601456a 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -65,6 +65,11 @@ 'feature_dir' : '/Data/LibriTTS', 'tokenizer_names': ['chartokenizer'], }, + 'grpo_valset': { + 'manifest_path' : '/Data/DPOPairsInputDatav2/text_context_pairs_grpo_val_unseenspeakers.json', + 'audio_dir' : '/', + 'feature_dir' : '/', + }, 'libri_unseen_test_shehzeen_sp': { 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', 'audio_dir' : '/Data/LibriTTS', diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 917c6a8fcc8d..10837d0c4407 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -39,7 +39,9 @@ def run_inference( use_cfg, cfg_scale, batch_size, - num_repeats=1, + sv_model, + asr_model_name, + num_repeats=1 apply_attention_prior=False, attention_prior_epsilon=1e-3, attention_prior_lookahead_window=10, @@ -68,7 +70,7 @@ def run_inference( # Load weights from checkpoint file print("Loading weights from checkpoint") - ckpt = torch.load(checkpoint_file) + ckpt = torch.load(checkpoint_file, weights_only=False) model.load_state_dict(ckpt['state_dict']) print("Loaded weights.") model.cuda() @@ -76,7 +78,7 @@ def run_inference( # import ipdb; ipdb.set_trace() checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] - checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}".format( + checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}_sv_{}".format( checkpoint_name, temperature, topk, @@ -88,7 +90,8 @@ def run_inference( start_prior_after_n_audio_steps, "".join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else "None", "".join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else "None", - use_local_transformer + use_local_transformer, + sv_model ) dataset_meta_info = evalset_config.dataset_meta_info for dataset in datasets: @@ -206,6 +209,8 @@ def run_inference( dataset_meta[dataset]['audio_dir'], pred_audio_dir, language=language, + sv_model_type=sv_model, + asr_model_name=asr_model_name, ) metrics_n_repeated.append(metrics) with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: @@ -264,7 +269,9 @@ def main(): parser.add_argument('--apply_prior_to_layers', type=str, default=None) 
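For orientation, the rewritten create_text_context_record above expects each line of --text_contexts to pack the context description, a context audio path and, optionally, a pre-extracted codes .pt path into one comma-separated string. A minimal sketch of that parsing with hypothetical values (the paths and texts below are illustrative, not from the repository):

line = "Speak in a calm low voice.,/data/contexts/calm_01.wav,/data/contexts/calm_01.pt\n"

text_context = line
if text_context.endswith("\n"):
    text_context = text_context[:-1]

record = {
    'text': "A challenging sentence to synthesize.",  # hypothetical target text
    'duration': 6.0,                                   # placeholder, only avoids DPO length filtering
    'audio_filepath': text_context.split(",")[1],      # context audio path
    'context_text': text_context.split(",")[0],        # context description
    'record_type': 'challenging',
}
if text_context.split(",")[-1].endswith(".pt"):
    record['target_audio_codes_path'] = text_context.split(",")[-1]
print(record)

Note that this format assumes the context text itself contains no commas, since the record fields are recovered purely by position after split(",").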
parser.add_argument('--start_prior_after_n_audio_steps', type=int, default=10) parser.add_argument('--topk', type=int, default=80) - parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--sv_model', type=str, default="titanet") # titanet, wavlm + parser.add_argument('--asr_model_name', type=str, default="stt_en_conformer_transducer_large") # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) args = parser.parse_args() @@ -294,6 +301,8 @@ def main(): use_cfg=args.use_cfg, cfg_scale=args.cfg_scale, batch_size=args.batch_size, + sv_model=args.sv_model, + asr_model_name=args.asr_model_name, num_repeats=args.num_repeats, apply_attention_prior=args.apply_attention_prior, attention_prior_epsilon=args.attention_prior_epsilon, @@ -354,6 +363,8 @@ def main(): use_cfg=args.use_cfg, cfg_scale=args.cfg_scale, batch_size=args.batch_size, + sv_model=args.sv_model, + asr_model_name=args.asr_model_name, num_repeats=args.num_repeats, apply_attention_prior=args.apply_attention_prior, attention_prior_epsilon=args.attention_prior_epsilon, diff --git a/scripts/t5tts/evaluate_generated_audio.py b/scripts/t5tts/evaluate_generated_audio.py new file mode 100644 index 000000000000..ecf75658a1d9 --- /dev/null +++ b/scripts/t5tts/evaluate_generated_audio.py @@ -0,0 +1,253 @@ +import argparse +import json +import os +import pprint +import string + +import torch + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.metrics.wer import word_error_rate_detail +from transformers import WhisperProcessor, WhisperForConditionalGeneration +import librosa +import evalset_config +from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector + +def find_sample_audios(audio_dir): + file_list = [] + for f in os.listdir(audio_dir): + if "predicted_audio" in f and f.endswith(".wav"): + audio_number = int(f.split("_")[-1].split(".wav")[0]) + file_list.append((audio_number, os.path.join(audio_dir, f))) + file_list.sort() + file_list = [t[1] for t in file_list] + return file_list + +def read_manifest(manifest_path): + records = [] + with open(manifest_path, 'r') as f: + all_lines = f.readlines() + for line in all_lines: + line = line.strip() + records.append(json.loads(line)) + return records + +def process_text(input_text): + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + return single_space_text + +def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, language, device): + speech_array, sampling_rate = librosa.load(audio_path, sr=16000) + # Set the language task (optional, improves performance for specific languages) + forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language) if language else None + inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features + inputs = inputs.to(device) + # Generate transcription + with torch.no_grad(): + predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) + + # Decode 
transcription + transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) + result = transcription[0] + return result + +def extract_embedding(model, extractor, audio_path, device, sv_model_type): + speech_array, sampling_rate = librosa.load(audio_path, sr=16000) + + if sv_model_type == "wavlm": + inputs = extractor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device) + with torch.no_grad(): + embeddings = model(inputs).embeddings + else: # Titanet + with torch.no_grad(): + embeddings = model.get_embedding(audio_path).squeeze() + + return embeddings.squeeze() + +def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large"): + audio_file_lists = find_sample_audios(generated_audio_dir) + records = read_manifest(manifest_path) + assert len(audio_file_lists) == len(records) + + device = "cuda" + + if language == "en": + if asr_model_name == "stt_en_conformer_transducer_large": + asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_large") + elif asr_model_name == "nvidia/parakeet-ctc-0.6b": + asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") + + # asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="nvidia/parakeet-tdt-1.1b") + asr_model = asr_model.to(device) + asr_model.eval() + else: + whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + whisper_model = whisper_model.to(device) + whisper_model.eval() + + if sv_model_type == "wavlm": + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv') + speaker_verification_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv').to(device).eval() + else: + feature_extractor = None + speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + speaker_verification_model = speaker_verification_model.to(device) + speaker_verification_model.eval() + + speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_small') + speaker_verification_model_alternate = speaker_verification_model_alternate.to(device) + speaker_verification_model_alternate.eval() + + + + filewise_metrics = [] + pred_texts = [] + gt_texts = [] + gt_audio_texts = [] + for ridx, record in enumerate(records): + gt_audio_filepath = record['audio_filepath'] + context_audio_filepath = record.get('context_audio_filepath', None) + if audio_dir is not None: + gt_audio_filepath = os.path.join(audio_dir, gt_audio_filepath) + if context_audio_filepath is not None: + context_audio_filepath = os.path.join(audio_dir, context_audio_filepath) + + pred_audio_filepath = audio_file_lists[ridx] + if language == "en": + with torch.no_grad(): + if asr_model_name == "stt_en_conformer_transducer_large": + pred_text = asr_model.transcribe([pred_audio_filepath])[0][0] + gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0][0] + else: + pred_text = asr_model.transcribe([pred_audio_filepath])[0] + gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0] + + pred_text = process_text(pred_text) + gt_audio_text = process_text(gt_audio_text) + else: + pred_text = transcribe_with_whisper(whisper_model, whisper_processor, pred_audio_filepath, language, 
device) + pred_text = process_text(pred_text) + gt_audio_text = transcribe_with_whisper(whisper_model, whisper_processor, gt_audio_filepath, language, device) + gt_audio_text = process_text(gt_audio_text) + + if 'normalized_text' in record: + gt_text = process_text(record['normalized_text']) + else: + gt_text = process_text(record['text']) + + detailed_cer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=True) + detailed_wer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=False) + + print("{} GT Text:".format(ridx), gt_text) + print("{} Pr Text:".format(ridx), pred_text) + # Format cer and wer to 2 decimal places + print("CER:", "{:.4f} | WER: {:.4f}".format(detailed_cer[0], detailed_wer[0])) + + pred_texts.append(pred_text) + gt_texts.append(gt_text) + gt_audio_texts.append(gt_audio_text) + + pred_context_ssim = 0.0 + gt_context_ssim = 0.0 + with torch.no_grad(): + gt_speaker_embedding = extract_embedding(speaker_verification_model, feature_extractor, gt_audio_filepath, device, sv_model_type) + pred_speaker_embedding = extract_embedding(speaker_verification_model, feature_extractor, pred_audio_filepath, device, sv_model_type) + pred_gt_ssim = torch.nn.functional.cosine_similarity(gt_speaker_embedding, pred_speaker_embedding, dim=0).item() + + gt_speaker_embedding_alternate = speaker_verification_model_alternate.get_embedding(gt_audio_filepath).squeeze() + pred_speaker_embedding_alternate = speaker_verification_model_alternate.get_embedding(pred_audio_filepath).squeeze() + pred_gt_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0).item() + + if context_audio_filepath is not None: + context_speaker_embedding = extract_embedding(speaker_verification_model, feature_extractor, context_audio_filepath, device, sv_model_type) + context_speaker_embedding_alternate = speaker_verification_model_alternate.get_embedding(context_audio_filepath).squeeze() + + pred_context_ssim = torch.nn.functional.cosine_similarity(pred_speaker_embedding, context_speaker_embedding, dim=0).item() + gt_context_ssim = torch.nn.functional.cosine_similarity(gt_speaker_embedding, context_speaker_embedding, dim=0).item() + + pred_context_ssim_alternate = torch.nn.functional.cosine_similarity(pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() + gt_context_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() + + + + filewise_metrics.append({ + 'gt_text': gt_text, + 'pred_text': pred_text, + 'gt_audio_text': gt_audio_text, + 'detailed_cer': detailed_cer, + 'detailed_wer': detailed_wer, + 'cer': detailed_cer[0], + 'wer': detailed_wer[0], + 'pred_gt_ssim': pred_gt_ssim, + 'pred_context_ssim': pred_context_ssim, + 'gt_context_ssim': gt_context_ssim, + 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, + 'pred_context_ssim_alternate': pred_context_ssim_alternate, + 'gt_context_ssim_alternate': gt_context_ssim_alternate, + 'gt_audio_filepath': gt_audio_filepath, + 'pred_audio_filepath': pred_audio_filepath, + 'context_audio_filepath': context_audio_filepath + }) + + filewise_metrics_keys_to_save = ['cer', 'wer', 'pred_context_ssim', 'pred_text', 'gt_text', 'gt_audio_filepath', 'pred_audio_filepath', 'context_audio_filepath'] + filtered_filewise_metrics = [] + for m in filewise_metrics: + filtered_filewise_metrics.append({k: m[k] for k in filewise_metrics_keys_to_save}) + + # Sort 
filewise metrics by cer in reverse + filewise_metrics.sort(key=lambda x: x['cer'], reverse=True) + + avg_metrics = {} + avg_metrics['cer_filewise_avg'] = sum([m['detailed_cer'][0] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['wer_filewise_avg'] = sum([m['detailed_wer'][0] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['cer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=True)[0] + avg_metrics['wer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[0] + avg_metrics['ssim_pred_gt_avg'] = sum([m['pred_gt_ssim'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_pred_context_avg'] = sum([m['pred_context_ssim'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_gt_context_avg'] = sum([m['gt_context_ssim'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_pred_gt_avg_alternate'] = sum([m['pred_gt_ssim_alternate'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_pred_context_avg_alternate'] = sum([m['pred_context_ssim_alternate'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_gt_context_avg_alternate'] = sum([m['gt_context_ssim_alternate'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics["cer_gt_audio_cumulative"] = word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=True)[0] + avg_metrics["wer_gt_audio_cumulative"] = word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=False)[0] + + pprint.pprint(avg_metrics) + + return avg_metrics, filewise_metrics + +def main(): + # audio_dir="/datap/misc/Datasets/riva" \ + parser = argparse.ArgumentParser(description='Evaluate Generated Audio') + parser.add_argument('--manifest_path', type=str, default=None) + parser.add_argument('--audio_dir', type=str, default=None) + parser.add_argument('--generated_audio_dir', type=str, default=None) + parser.add_argument('--whisper_language', type=str, default="en") + parser.add_argument('--evalset', type=str, default=None) + args = parser.parse_args() + + if args.evalset is not None: + dataset_meta_info = evalset_config.dataset_meta_info + assert args.evalset in dataset_meta_info + args.manifest_path = dataset_meta_info[args.evalset]['manifest_path'] + args.audio_dir = dataset_meta_info[args.evalset]['audio_dir'] + + evaluate(args.manifest_path, args.audio_dir, args.generated_audio_dir, args.whisper_language, sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b") + + + +if __name__ == "__main__": + main() From c94d3566ebfe5a3e5eb79c3c6bdc8ac7fdbac227 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Thu, 13 Mar 2025 17:16:51 -0700 Subject: [PATCH 007/113] bug fixes after merge Signed-off-by: Shehzeen Hussain --- nemo/collections/tts/models/t5tts_preference_optimization.py | 4 ++-- scripts/magpietts/infer_and_evaluate.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/t5tts_preference_optimization.py b/nemo/collections/tts/models/t5tts_preference_optimization.py index bd19fe3dbe88..6f205e590ba1 100644 --- a/nemo/collections/tts/models/t5tts_preference_optimization.py +++ b/nemo/collections/tts/models/t5tts_preference_optimization.py @@ -51,7 +51,7 @@ def test_step(self, batch, batch_idx): topk = self.cfg.get('inference_topk', 80) use_cfg = self.cfg.get('inference_use_cfg', False) cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - predicted_audio, 
predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch, max_decoder_steps=self.cfg.get('max_decoder_steps', 500), temperature=temperature, @@ -466,7 +466,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): use_cfg = random.random() < self.cfg.inference_cfg_prob cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) print("use_cfg", use_cfg) - predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch_repeated, max_decoder_steps=self.cfg.get('max_decoder_steps', 430), temperature=temperature, diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 10837d0c4407..2e0f7217e8c6 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -41,7 +41,7 @@ def run_inference( batch_size, sv_model, asr_model_name, - num_repeats=1 + num_repeats=1, apply_attention_prior=False, attention_prior_epsilon=1e-3, attention_prior_lookahead_window=10, From 8f6a22e930771e2894d66b250ce58e4969bc1d4e Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 18 Mar 2025 13:41:03 -0400 Subject: [PATCH 008/113] Cleanup 2503 to include dev files and fixes naming (#45) * add back missing dev files Signed-off-by: Jason * more bug fixes from merge Signed-off-by: Jason * add latest changes for rc5 docker Signed-off-by: Jason --------- Signed-off-by: Jason --- .../tts/conf/magpietts/magpietts_lhotse.yaml | 14 +- examples/tts/magpietts.py | 8 +- .../tts/data/text_to_speech_dataset_lhotse.py | 316 +++++++++++++++++ nemo/collections/tts/models/__init__.py | 11 +- nemo/collections/tts/models/magpietts.py | 19 +- ...y => magpietts_preference_optimization.py} | 104 +++--- scripts/magpietts/codec_extraction.py | 255 ++++++++++++++ .../magpietts/dpo/create_text_contextpairs.py | 2 +- scripts/magpietts/eval_squimmos.py | 67 ++++ scripts/magpietts/evalset_config.py | 6 +- .../evaluate_generated_audio.py | 52 ++- scripts/magpietts/infer_and_evaluate.py | 88 ++--- scripts/tts_dataset_to_lhotse/README.md | 82 +++++ scripts/tts_dataset_to_lhotse/create_shars.py | 167 +++++++++ t5tts_inference_multiturndialogues.ipynb | 320 ++++++++++++++++++ 15 files changed, 1359 insertions(+), 152 deletions(-) create mode 100644 nemo/collections/tts/data/text_to_speech_dataset_lhotse.py rename nemo/collections/tts/models/{t5tts_preference_optimization.py => magpietts_preference_optimization.py} (97%) create mode 100644 scripts/magpietts/codec_extraction.py create mode 100644 scripts/magpietts/eval_squimmos.py rename scripts/{t5tts => magpietts}/evaluate_generated_audio.py (94%) create mode 100644 scripts/tts_dataset_to_lhotse/README.md create mode 100644 scripts/tts_dataset_to_lhotse/create_shars.py create mode 100644 t5tts_inference_multiturndialogues.ipynb diff --git a/examples/tts/conf/magpietts/magpietts_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_lhotse.yaml index 26d58cac63e3..17b3597a9a3f 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse.yaml @@ -1,4 +1,4 @@ -name: T5TTS +name: Magpie-TTS-Lhotse max_steps: ??? limit_val_batches: ??? @@ -50,7 +50,7 @@ model: alignment_encoder_loss_scale: 1.0 binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. 
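As the two call-site fixes above indicate, infer_batch now returns five values instead of four, and the trailing value is simply discarded at these call sites. A minimal sketch of the updated unpacking, assuming model is an already-loaded preference-optimization model, batch is a prepared batch on the correct device, and the sampling settings are hypothetical:

# Sketch only: `model` and `batch` are assumed to exist; the fifth return value
# is discarded, mirroring the call sites patched above.
predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = model.infer_batch(
    batch,
    max_decoder_steps=500,
    temperature=0.7,   # hypothetical sampling settings
    topk=80,
    use_cfg=False,
    cfg_scale=1.0,
)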
binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. - prior_future_context: 2 # Future window of the binarized prior. + prior_future_context: 2 # Future window of the binarized prior. prior_past_context: 2 # Past window of the binarized prior. prior_future_decay: 0.8 # Decay factor for future context prior_past_decay: 0.5 # Decay factor for past context @@ -176,8 +176,8 @@ model: use_bucketing: false is_tarred: false batch_size: ${eval_batch_size} - - t5_encoder: + + encoder: n_layers: 6 d_model: 768 d_ffn: 3072 @@ -190,7 +190,7 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true - + context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise n_layers: 3 d_model: 768 @@ -205,7 +205,7 @@ model: max_length_causal_mask: 2048 use_learnable_pos_emb: true - t5_decoder: + decoder: n_layers: 12 d_model: 768 d_ffn: 3072 @@ -259,7 +259,7 @@ exp_manager: wandb_logger_kwargs: name: null project: null - create_checkpoint_callback: true + create_checkpoint_callback: true checkpoint_callback_params: monitor: val_loss mode: min diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index a9de63ca5582..7f745ced2b6e 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -15,7 +15,7 @@ import pytorch_lightning as pl from omegaconf import OmegaConf, open_dict -from nemo.collections.tts.models import MagpieTTS_Model, MagpieTTS_ModelInference, T5TTS_Model_PrefDataGen, T5TTS_Model_OfflinePO, T5TTS_Model_OnlinePO +from nemo.collections.tts.models import MagpieTTS_Model, MagpieTTS_ModelInference, MagpieTTS_Model_PrefDataGen, MagpieTTS_Model_OfflinePO, MagpieTTS_Model_OnlinePO from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @@ -38,16 +38,16 @@ def main(cfg): model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt - model = T5TTS_Model_OfflinePO(cfg=model_cfg, trainer=trainer) + model = MagpieTTS_Model_OfflinePO(cfg=model_cfg, trainer=trainer) elif cfg.get('mode', 'train') == 'onlinepo_train': model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt - model = T5TTS_Model_OnlinePO(cfg=model_cfg, trainer=trainer) + model = MagpieTTS_Model_OnlinePO(cfg=model_cfg, trainer=trainer) elif cfg.get('mode', 'train') == 'test': model = MagpieTTS_ModelInference(cfg=cfg.model, trainer=trainer) # elif cfg.get('mode', 'train') == 'test': - # model = T5TTS_Model_PrefDataGen(cfg=cfg.model, trainer=trainer) + # model = MagpieTTS_Model_PrefDataGen(cfg=cfg.model, trainer=trainer) else: raise NotImplementedError(f"Only train, dpo_train and test modes are supported. Got {cfg.mode}") diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py new file mode 100644 index 000000000000..0a6b8e9f66b7 --- /dev/null +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -0,0 +1,316 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
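Because the Lhotse config renames the t5_encoder and t5_decoder sections to encoder and decoder, any code or script that reads this YAML has to use the new keys. A small sketch, assuming the config path below and that the mandatory ??? fields are filled elsewhere before training:

from omegaconf import OmegaConf

# Assumed path, based on the file touched in this commit.
cfg = OmegaConf.load("examples/tts/conf/magpietts/magpietts_lhotse.yaml")

# The former `model.t5_encoder` / `model.t5_decoder` blocks are now:
print(cfg.model.encoder.n_layers, cfg.model.encoder.d_model)  # 6, 768 in this config
print(cfg.model.decoder.n_layers, cfg.model.decoder.d_model)  # 12, 768 in this config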
+# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from pathlib import Path +from typing import List, Optional + +import librosa +import torch +import random +from lhotse.dataset.collation import collate_vectors as collate_vectors_lhotse +from megatron.core import parallel_state +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.tts.parts.utils.tts_dataset_utils import beta_binomial_prior_distribution, stack_tensors +from nemo.utils import logging +from nemo.utils.decorators import experimental + + +def collate_vectors(items, max_length: int, padding_value): + vectors = collate_vectors_lhotse(items, padding_value=padding_value) + if max_length > vectors.size(1): + vectors = torch.cat( + [vectors, padding_value * torch.ones(vectors.size(0), max_length - vectors.size(1), dtype=vectors.dtype)], + dim=1, + ) + if items[0].shape[0] < 1: + vectors = vectors.long() + return vectors + +def normalize_volume_torch(audio, volume_level: float = 0.95): + """Apply peak normalization to the input audio. + """ + if not (0.0 <= volume_level <= 1.0): + raise ValueError(f"Volume must be in range [0.0, 1.0], received {volume_level}") + + if audio.size == 0: + return audio + + max_sample = torch.max(torch.abs(audio)) + if max_sample == 0: + return audio + + return volume_level * (audio / torch.max(torch.abs(audio))) + +def build_lhotse_dataloader(dataset, data_cfg, is_eval=False): + """Buld dataloader given an input dataset.""" + return get_lhotse_dataloader_from_config( + data_cfg, + global_rank=parallel_state.get_data_parallel_rank(), + world_size=parallel_state.get_data_parallel_world_size(), + dataset=dataset, + ) + + +class MagpieTTSLhotseDataset(torch.utils.data.Dataset): + """ + Class for processing and loading text to speech training examples. + + Args: + sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will + be resampled. + text_tokenizer: Tokenizer to apply to the text field. + speaker_path: Optional, path to JSON file with speaker indices, for multi-speaker training. Can be created with + scripts.dataset_processing.tts.create_speaker_map.py + featurizers: Optional, list of featurizers to load feature data from. Should be the same config provided + when running scripts.dataset_processing.tts.compute_features.py before training. + feature_processors: Optional, list of feature processors to run on training examples. + align_prior_hop_length: Optional int, hop length of audio features. + If provided alignment prior will be calculated and included in batch output. Must match hop length + of audio features used for training. + min_duration: Optional float, if provided audio files in the training manifest shorter than 'min_duration' + will be ignored. + max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration' + will be ignored. + volume_norm: Whether to apply volume normalization to loaded audio. 
+ """ + def __init__( + self, + sample_rate: int, + align_prior_hop_length: Optional[int] = None, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + volume_norm: bool = True, + codec_model_downsample_factor: int = None, + bos_id: int = None, + eos_id: int = None, + audio_bos_id: int = None, + audio_eos_id: int = None, + prior_scaling_factor: float = None, + load_cached_codes_if_available: bool = True, + dataset_type: str = 'train', + tokenizer_config=None, + load_16khz_audio: bool = True, + use_text_conditioning_tokenizer: bool = False, + pad_context_text_to_max_duration: bool = False, + context_duration_min: float = 3.0, + context_duration_max: float = 10.0 + ): + super().__init__() + self.sample_rate = sample_rate + self.text_tokenizer = None + self.align_prior_hop_length = align_prior_hop_length + self.volume_norm = volume_norm + + self.bos_id = bos_id + self.eos_id = eos_id + self.audio_bos_id = audio_bos_id + self.audio_eos_id = audio_eos_id + self.codec_model_downsample_factor = codec_model_downsample_factor + self.include_align_prior = prior_scaling_factor is not None + self.prior_scaling_factor = prior_scaling_factor + self.load_cached_codes_if_available = load_cached_codes_if_available + self.dataset_type = dataset_type + self.tokenizer_config = tokenizer_config + self.load_16khz_audio = load_16khz_audio + self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer + self.text_conditioning_tokenizer = None + self.pad_context_text_to_max_duration = pad_context_text_to_max_duration + self.context_duration_min = context_duration_min + self.context_duration_max = context_duration_max + + def __getitem__(self, cuts): + cuts = cuts.sort_by_duration() + + logging.debug(f"Len: {len(cuts)}") + + # load audios and text + num_codec_frames = [] + align_priors = [] + context_audios = [] + context_audios_lens = [] + target_audios = [] + target_audios_lens = [] + target_audios_16khz = [] + target_audios_16khz_lens = [] + context_text_tokens = [] + context_text_tokens_lens = [] + has_text_context_list = [] + target_text_tokens = [] + target_text_tokens_lens = [] + + for i, cut in enumerate(cuts): + # load target/answer audio + answer_audio = torch.FloatTensor(cut.target_audio.resample(self.sample_rate).load_audio()).squeeze(0) + if self.volume_norm: + answer_audio = normalize_volume_torch(answer_audio) + + answer_audio = torch.nn.functional.pad( + answer_audio, + (0, self.codec_model_downsample_factor - (answer_audio.shape[0] % self.codec_model_downsample_factor)), + value=0 + ).unsqueeze(0) + + answer_audio_len = answer_audio.shape[1] + target_audios.append(answer_audio) + target_audios_lens.append(answer_audio_len) + num_frames = int(answer_audio_len / self.codec_model_downsample_factor) + 1 # +1 for EOS + num_codec_frames.append(num_frames) + + # load context audio + context_audio = torch.FloatTensor(cut.resample(self.sample_rate).load_audio()).squeeze(0) + if self.volume_norm: + context_audio = normalize_volume_torch(context_audio) + + context_audio = torch.nn.functional.pad( + context_audio, + (0, self.codec_model_downsample_factor - (context_audio.shape[0] % self.codec_model_downsample_factor)), + value=0 + ).unsqueeze(0) + context_audios_len = context_audio.shape[1] + context_audios.append(context_audio) + context_audios_lens.append(context_audios_len) + + # load context text + if cut.supervisions[0].speaker == "user": + if self.use_text_conditioning_tokenizer: + context_text = cut.supervisions[0].text + context_tokenizer = 
self.text_conditioning_tokenizer if self.text_conditioning_tokenizer else self.text_tokenizer + # check if the text is not empty + if context_text.replace(" ", ""): + context_text = self.text_conditioning_tokenizer(context_text)['input_ids'] + has_text_context_list.append(True) + else: + context_text = self.text_conditioning_tokenizer("[NO TEXT CONTEXT]")['input_ids'] + has_text_context_list.append(False) + + if self.pad_context_text_to_max_duration: + _required_len = int(self.context_duration_max * self.sample_rate / self.codec_model_downsample_factor) + 2 # +2 for BOS and EOS + if len(context_text) < _required_len: + _pad_id = self.text_conditioning_tokenizer.pad_token_id + context_text += [_pad_id] * (_required_len - len(context_text)) + else: + context_text = context_text[:_required_len] + + context_text = torch.tensor(context_text, dtype=torch.int32) + context_text_len = context_text.shape[0] + context_text_tokens.append(context_text) + context_text_tokens_lens.append(context_text_len) + else: + raise Exception("First speaker should be user") + + if cut.supervisions[1].speaker == "agent": + target_text = cut.supervisions[1].text + # check if the text is not empty + if target_text.replace(" ", ""): + tokenizer_name = "english_phoneme" # Default to english phoneme tokenizer + if getattr(cut, "tokenizer_names", None): + # Pick a random tokenizer from the list of tokenizers + tokenizer_name = random.choice(cut.tokenizer_names) + + target_text = self.text_tokenizer.encode(text=target_text, tokenizer_name=tokenizer_name) + target_text = target_text + [self.eos_id] + else: + target_text = [self.eos_id] + + target_text = torch.tensor(target_text, dtype=torch.int32) + target_text_len = target_text.shape[0] + target_text_tokens.append(target_text) + target_text_tokens_lens.append(target_text_len) + else: + raise Exception("Second speaker should be agent") + + if self.include_align_prior: + # align_prior = self.beta_binomial_interpolator(spec_len, text_len) + align_prior = beta_binomial_prior_distribution(phoneme_count=target_text_len, mel_count=num_frames, scaling_factor=self.prior_scaling_factor) + align_prior = torch.tensor(align_prior, dtype=torch.float32) + align_priors.append(align_prior) + + if self.load_16khz_audio: + target_audio_16khz = librosa.resample(answer_audio.squeeze(0).numpy(), orig_sr=self.sample_rate, target_sr=16000) + target_audio_16khz = torch.FloatTensor(target_audio_16khz).unsqueeze(0) + target_audio_16khz_len = target_audio_16khz.shape[1] + target_audios_16khz.append(target_audio_16khz) + target_audios_16khz_lens.append(target_audio_16khz_len) + + # collate target/agent audios + target_audios = collate_vectors( + [a.squeeze(0) for a in target_audios], max_length=max(target_audios_lens), padding_value=0.0 + ).float() + target_audios_lens = torch.IntTensor(target_audios_lens) + num_codec_frames = torch.IntTensor(num_codec_frames) + + # collate context/user audios + context_audios = collate_vectors( + [a.squeeze(0) for a in context_audios], max_length=max(context_audios_lens), padding_value=0.0 + ).float() + context_audios_lens = torch.IntTensor(context_audios_lens) + + # collate context/user text + if self.use_text_conditioning_tokenizer: + context_text_tokens = collate_vectors(context_text_tokens, max_length=max(context_text_tokens_lens), padding_value=self.text_tokenizer.pad) + context_text_tokens_lens = torch.IntTensor(context_text_tokens_lens) + + # collate target/agent text + target_text_tokens = collate_vectors(target_text_tokens, 
max_length=max(target_text_tokens_lens), padding_value=self.text_tokenizer.pad) + target_text_tokens_lens = torch.IntTensor(target_text_tokens_lens) + + # collate align prior + if self.include_align_prior: + spec_max_len = max([prior.shape[0] for prior in align_priors]) + text_max_len = max([prior.shape[1] for prior in align_priors]) + align_priors = stack_tensors(align_priors, max_lens=[text_max_len, spec_max_len],) + + # collate 16khz target/agent audio + if self.load_16khz_audio: + target_audios_16khz = collate_vectors( + [a.squeeze(0) for a in target_audios_16khz], max_length=max(target_audios_16khz_lens), padding_value=0.0 + ).float() + target_audios_16khz_lens = torch.IntTensor(target_audios_16khz_lens) + + batch_dict = { + # "dataset_names": dataset_names, + # "audio_filepaths": audio_filepath_list, + "sample_ids": list(cuts.ids), + "text": target_text_tokens, + "text_lens": target_text_tokens_lens, + 'audio': target_audios, + 'audio_lens': target_audios_lens, + # 'audio_codes': batch_audio_codes + # 'audio_codes_lens': batch_audio_codes_len + 'context_audio': context_audios, + 'context_audio_lens': context_audios_lens, + # 'context_audio_codes': batch_context_audio_codes + # 'context_audio_codes_lens': batch_context_audio_codes_len + } + + if self.include_align_prior: + batch_dict["align_prior_matrix"] = align_priors + + if self.load_16khz_audio: + batch_dict['audio_16khz'] = target_audios_16khz + batch_dict['audio_lens_16khz'] = target_audios_16khz_lens + + if self.use_text_conditioning_tokenizer: + batch_dict['context_text_tokens'] = context_text_tokens + batch_dict['context_text_len'] = context_text_tokens_lens + batch_dict['has_text_context'] = torch.BoolTensor(has_text_context_list) + + return batch_dict + + + def collate_fn(self, batch: List[dict]): + return batch diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index c62800149460..203a662f4246 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -17,12 +17,12 @@ from nemo.collections.tts.models.fastpitch import FastPitchModel from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL from nemo.collections.tts.models.hifigan import HifiGanModel -from nemo.collections.tts.models.magpietts import MagpieTTS_Model, MagpieTTS_ModelDPO, MagpieTTS_ModelInference +from nemo.collections.tts.models.magpietts import MagpieTTS_Model, MagpieTTS_ModelInference +from nemo.collections.tts.models.magpietts_preference_optimization import MagpieTTS_Model_PrefDataGen, MagpieTTS_Model_OfflinePO, MagpieTTS_Model_OnlinePO from nemo.collections.tts.models.mixer_tts import MixerTTSModel from nemo.collections.tts.models.radtts import RadTTSModel from nemo.collections.tts.models.spectrogram_enhancer import SpectrogramEnhancerModel from nemo.collections.tts.models.ssl_tts import SSLDisentangler -from nemo.collections.tts.models.t5tts_preference_optimization import T5TTS_Model_PrefDataGen, T5TTS_Model_OfflinePO, T5TTS_Model_OnlinePO from nemo.collections.tts.models.tacotron2 import Tacotron2Model from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel from nemo.collections.tts.models.univnet import UnivNetModel @@ -42,10 +42,9 @@ "RadTTSModel", "MagpieTTS_Model", "MagpieTTS_ModelInference", - "MagpieTTS_ModelDPO", - "T5TTS_Model_PrefDataGen", - "T5TTS_Model_OfflinePO", - "T5TTS_Model_OnlinePO", + "MagpieTTS_Model_PrefDataGen", + "MagpieTTS_Model_OfflinePO", + "MagpieTTS_Model_OnlinePO", 
"Tacotron2Model", "TwoStagesModel", "UnivNetModel", diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 719c786fd475..508734b62dba 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -35,7 +35,7 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer -from nemo.collections.tts.data.text_to_speech_dataset_lhotse import build_lhotse_dataloader, T5TTSLhotseDataset +from nemo.collections.tts.data.text_to_speech_dataset_lhotse import build_lhotse_dataloader, MagpieTTSLhotseDataset from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 @@ -176,11 +176,11 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.final_proj = nn.Linear(cfg.decoder.d_model, cfg.num_audio_codebooks * cfg.num_audio_tokens_per_codebook) if cfg.get('use_local_transformer', False): local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256) - if local_transformer_hidden_dim != cfg.t5_decoder.d_model: - self.local_transformer_in_projection = nn.Linear(cfg.t5_decoder.d_model, local_transformer_hidden_dim) + if local_transformer_hidden_dim != cfg.decoder.d_model: + self.local_transformer_in_projection = nn.Linear(cfg.decoder.d_model, local_transformer_hidden_dim) else: self.local_transformer_in_projection = nn.Identity() - self.local_transformer = t5tts_transformer.Transformer( + self.local_transformer = transformer_2501.Transformer( n_layers=self.cfg.get('local_transformer_n_layers', 2), d_model=local_transformer_hidden_dim, d_ffn=local_transformer_hidden_dim*4, @@ -731,10 +731,10 @@ def prepare_context_tensors(self, batch): # Convert prior to a list of tensors, one for each layer # Set None for layers not in ctc_prior_layer_ids if self.model_type == 'multi_encoder_context_tts': - text_attn_prior = [attn_prior[0] if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.t5_decoder.n_layers) ] + text_attn_prior = [attn_prior[0] if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.decoder.n_layers) ] attn_prior = [text_attn_prior, attn_prior[1]] else: - attn_prior = [attn_prior if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.t5_decoder.n_layers) ] + attn_prior = [attn_prior if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.decoder.n_layers) ] return { 'beta_binomial_attn_prior': batch.get('align_prior_matrix', None), @@ -1269,7 +1269,7 @@ def infer_batch( _audio_codes_mask = audio_codes_mask if apply_prior_to_layers is not None: - attn_prior = [None for _ in range(self.cfg.t5_decoder.n_layers)] + attn_prior = [None for _ in range(self.cfg.decoder.n_layers)] for layer_idx in apply_prior_to_layers: attn_prior[layer_idx] = _attn_prior else: @@ -1278,6 +1278,7 @@ def infer_batch( if self.model_type == 'multi_encoder_context_tts': attn_prior = [attn_prior, None] + if use_cfg: batch_size = audio_codes_embedded.size(0) if isinstance(context_tensors['cond'], list): @@ -1304,6 +1305,10 @@ def infer_batch( dummy_addition_dec_mask ) + # print(f"step {idx}") + # print(f"use_cfg {use_cfg}") + # print(f"shape {cfg_audio_codes_embedded.shape}") + # print(f"use kv cahce? 
{self.use_kv_cache_for_inference}") combined_logits, attn_probs, dec_out = self.forward( dec_input_embedded=cfg_audio_codes_embedded, dec_input_mask=cfg_audio_codes_mask, diff --git a/nemo/collections/tts/models/t5tts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py similarity index 97% rename from nemo/collections/tts/models/t5tts_preference_optimization.py rename to nemo/collections/tts/models/magpietts_preference_optimization.py index 6f205e590ba1..f6e9c56eebee 100644 --- a/nemo/collections/tts/models/t5tts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -22,10 +22,10 @@ except ImportError: HAVE_TORCHAUDIO = False -from nemo.collections.tts.models import T5TTS_Model +from nemo.collections.tts.models import MagpieTTS_Model -class T5TTS_Model_PrefDataGen(T5TTS_Model): +class MagpieTTS_Model_PrefDataGen(MagpieTTS_Model): """Small override to save inference metrics, used for datagen in Offline PO""" def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) @@ -79,7 +79,7 @@ def test_step(self, batch, batch_idx): predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] torch.save(predicted_codes_torch, os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt')) predicted_audio_paths.append(audio_path) - + if not batch_invalid: with torch.no_grad(): try: @@ -114,7 +114,7 @@ def test_step(self, batch, batch_idx): spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy() spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy() - + spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) ) @@ -138,7 +138,7 @@ def test_step(self, batch, batch_idx): with open(os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w') as f: json.dump(item_metrics, f) -class T5TTS_Model_OfflinePO(T5TTS_Model): +class MagpieTTS_Model_OfflinePO(MagpieTTS_Model): def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) # Copy cfg @@ -146,14 +146,14 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): with open_dict(ref_model_cfg): ref_model_cfg.train_ds = None ref_model_cfg.validation_ds = None - self._reference_model = T5TTS_Model(cfg=ref_model_cfg) + self._reference_model = MagpieTTS_Model(cfg=ref_model_cfg) print("Loading reference model from checkpoint") self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) self.freeze_model(self._reference_model) self._reference_model.eval() self._reference_model._no_state_dict = True print("Reference model loaded and frozen") - + def state_dict(self, destination=None, prefix='', keep_vars=False): state_dict = super().state_dict(destination, prefix, keep_vars) keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model'] @@ -161,7 +161,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): if any([substring in key for substring in keys_substrings_to_exclude]): del state_dict[key] return state_dict - + def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): """Compute the log probabilities of the given labels under the given logits. 
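Only the signature and docstring of _get_batch_logps appear in this hunk. As rough orientation, the following is a generic sketch of what such a helper typically computes (the summed log-probability of the label tokens under the logits, restricted to masked positions); it is an assumption, not the repository's exact implementation:

import torch
import torch.nn.functional as F

def batch_logps_sketch(logits, labels, loss_mask):
    # logits: (B, T, V); labels and loss_mask: (B, T)
    log_probs = F.log_softmax(logits, dim=-1)
    per_token = torch.gather(log_probs, 2, labels.long().unsqueeze(-1)).squeeze(-1)
    return (per_token * loss_mask).sum(dim=-1)  # one summed log-prob per sequence

out = batch_logps_sketch(torch.randn(2, 5, 10), torch.randint(0, 10, (2, 5)), torch.ones(2, 5))
print(out.shape)  # torch.Size([2])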
@@ -252,13 +252,13 @@ def preference_loss(self, policy_chosen_logps, def process_batch_dpo(self, batch_chosen_rejected): batch_chosen = batch_chosen_rejected['chosen'] batch_rejected = batch_chosen_rejected['rejected'] - + model_output_chosen = self.process_batch(batch_chosen) model_output_rejected = self.process_batch(batch_rejected) with torch.no_grad(): reference_model_output_chosen = self._reference_model.process_batch(batch_chosen) reference_model_output_rejected = self._reference_model.process_batch(batch_rejected) - + chosen_policy_logprobs = None rejected_policy_logprobs = None chosen_ref_logprobs = None @@ -280,7 +280,7 @@ def process_batch_dpo(self, batch_chosen_rejected): with torch.no_grad(): ref_codebook_log_probs_chosen = self._get_batch_logps(ref_codebook_logits_chosen, codebook_labels_chosen, reference_model_output_chosen['loss_mask']) ref_codebook_log_probs_rejected = self._get_batch_logps(ref_codebook_logits_rejected, codebook_labels_rejected, reference_model_output_rejected['loss_mask']) - + if chosen_policy_logprobs is None: chosen_policy_logprobs = codebook_log_probs_chosen rejected_policy_logprobs = codebook_log_probs_rejected @@ -291,10 +291,10 @@ def process_batch_dpo(self, batch_chosen_rejected): rejected_policy_logprobs += codebook_log_probs_rejected chosen_ref_logprobs += ref_codebook_log_probs_chosen rejected_ref_logprobs += ref_codebook_log_probs_rejected - + rewards_chosen = batch_chosen['rewards'] rewards_rejected = batch_rejected['rewards'] - + assert torch.all(rewards_chosen == 1) assert torch.all(rewards_rejected < 1) @@ -311,7 +311,7 @@ def process_batch_dpo(self, batch_chosen_rejected): pref_loss = pref_loss.mean() sft_loss = -chosen_policy_logprobs.mean() - + pref_loss_weight = self.cfg.get('dpo_pref_loss_weight', 1.0) sft_loss_weight = self.cfg.get('dpo_sft_loss_weight', 0.0) loss = pref_loss_weight * pref_loss + sft_loss * sft_loss_weight @@ -319,7 +319,7 @@ def process_batch_dpo(self, batch_chosen_rejected): alignment_loss = model_output_chosen['alignment_loss'] if alignment_loss is not None: loss += alignment_loss - + return { 'loss': loss, 'pref_loss': pref_loss, @@ -333,22 +333,22 @@ def training_step(self, batch, batch_idx): self.log('train_pref_loss', dpo_outputs['pref_loss'], prog_bar=True, sync_dist=True) self.log('train_sft_loss', dpo_outputs['sft_loss'], prog_bar=True, sync_dist=True) return dpo_outputs['loss'] - + def validation_step(self, batch, batch_idx): dpo_outputs = self.process_batch_dpo(batch) - + val_loss = dpo_outputs['loss'] val_pref_loss = dpo_outputs['pref_loss'] val_sft_loss = dpo_outputs['sft_loss'] val_alignment_loss = dpo_outputs['alignment_loss'] - + self.validation_step_outputs.append({ 'val_loss': val_loss, 'val_pref_loss': val_pref_loss, 'val_sft_loss': val_sft_loss, 'val_alignment_loss': val_alignment_loss, }) - + def on_validation_epoch_end(self): def collect(key): values = [] @@ -371,7 +371,7 @@ def collect(key): self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() -class T5TTS_Model_OnlinePO(T5TTS_Model): +class MagpieTTS_Model_OnlinePO(MagpieTTS_Model): def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) # Copy cfg @@ -379,10 +379,10 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): with open_dict(ref_model_cfg): ref_model_cfg.train_ds = None ref_model_cfg.validation_ds = None - + self.reference_free = self.cfg.get('reference_free', False) # True means we dont use the reference model if not 
self.reference_free: - self._reference_model = T5TTS_Model(cfg=ref_model_cfg) + self._reference_model = MagpieTTS_Model(cfg=ref_model_cfg) print("Loading reference model from checkpoint") self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) self.freeze_model(self._reference_model) @@ -411,22 +411,22 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() - + use_pesq = self.cfg.get('use_pesq', False) if use_pesq: # import ipdb; ipdb.set_trace() assert HAVE_TORCHAUDIO, "torchaudio is required for PESQ reward" self.squim_objective_model = SQUIM_OBJECTIVE.get_model() - + def state_dict(self, destination=None, prefix='', keep_vars=False): state_dict = super().state_dict(destination, prefix, keep_vars) - keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model', + keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model', 'eval_asr_model', 'eval_speaker_verification_model', 'whisper_model'] for key in list(state_dict.keys()): if any([substring in key for substring in keys_substrings_to_exclude]): del state_dict[key] return state_dict - + def _get_per_token_logps(self, logits, labels, loss_mask): """Compute the log probabilities of the given labels under the given logits. @@ -495,7 +495,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] # C, T torch.save(predicted_codes_torch, os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt')) predicted_audio_paths.append(audio_path) - + with torch.no_grad(): if self.cfg.get("reward_asr_model", "nemo") == "nemo": pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths)) @@ -507,10 +507,10 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): transcript = transcribe_with_whisper(audio_path, language, self.whisper_processor, self.whisper_model, self.device) pred_transcripts.append(transcript) pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] - + pred_speaker_embeddings = get_speaker_embeddings_from_filepaths(predicted_audio_paths, self.eval_speaker_verification_model, self.device) gt_speaker_embeddings = get_speaker_embeddings_from_filepaths(batch_repeated['audio_filepaths'], self.eval_speaker_verification_model, self.device) - + batch_metrics = [] cer_reward_weight = self.cfg.get('cer_reward_weight', 0.5) @@ -552,7 +552,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): batch_metrics.append(item_metrics) num_groups = len(batch['audio_filepaths']) - + best_ssim_achievable = self.cfg.get("best_ssim_achievable", 0.9) # Examples with this speaker similarity or higher will have SSIM reward of 1 mean_cer_dataset = self.cfg.get("mean_cer_dataset", 0.1) # CER equal to this value will have reward of 0.5 mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6) # SSIM equal to this value will have reward of 0.5 @@ -584,7 +584,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): pesq_reward = 0.0 batch_metrics[idx]['reward'] = cer_reward * cer_reward_weight + spk_similarity_reward * 
ssim_reward_weight + pesq_reward * pesq_reward_weight - + if (batch_metrics[idx]['codes_len'] >= 425) or (batch_metrics[idx]['codes_len'] <= 3): # TODO: Remove hardcoded lengths # This means it did not complete the sentence or generated an extremely short sentence batch_metrics[idx]['reward'] = 0.0 @@ -594,7 +594,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): batch_metrics[idx]['pesq_reward'] = pesq_reward mean_reward += batch_metrics[idx]['reward'] group_rewards.append(batch_metrics[idx]['reward']) - + mean_reward /= num_generations_per_item std_reward = np.std(group_rewards) for idx in range(group_start_idx, group_end_idx): @@ -612,20 +612,20 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): 'predicted_codes_lens': predicted_codes_lens, 'advantages': advantages, } - + def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): use_kv_cache_during_online_po = self.cfg.get("use_kv_cache_during_online_po", False) if use_kv_cache_during_online_po: self.use_kv_cache_for_inference = True - self.t5_decoder.reset_cache(use_cache=True) - + self.decoder.reset_cache(use_cache=True) + with torch.no_grad(): generated_codes_and_metrics = self.generate_and_reward(batch, n_generations_per_item, mode) - + if use_kv_cache_during_online_po: self.use_kv_cache_for_inference = False - self.t5_decoder.reset_cache(use_cache=False) - + self.decoder.reset_cache(use_cache=False) + batch_repeated = generated_codes_and_metrics['batch_repeated'] predicted_codes = generated_codes_and_metrics['predicted_codes'] # B, 8, T predicted_codes_lens = generated_codes_and_metrics['predicted_codes_lens'] # B @@ -644,13 +644,13 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): del batch_repeated['audio'] if 'audio_lens' in batch_repeated: del batch_repeated['audio_lens'] - + policy_model_outputs = self.process_batch(batch_repeated) if not self.reference_free: with torch.no_grad(): reference_model_output = self._reference_model.process_batch(batch_repeated) - + total_loss = None total_kl = None for codebook_idx in range(self.cfg.num_audio_codebooks): @@ -660,7 +660,7 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): codebook_labels = batch_repeated['audio_codes'][:,codebook_idx,1:] per_token_codebook_log_probs = self._get_per_token_logps(codebook_logits, codebook_labels, policy_model_outputs['loss_mask']) per_token_loss = -(torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach()) * advantages.unsqueeze(1)) - + if not self.reference_free: with torch.no_grad(): @@ -672,7 +672,7 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): codebook_kl_loss_mean = ((per_token_codebook_kl * policy_model_outputs['loss_mask']).sum(dim=1) / policy_model_outputs['loss_mask'].sum(dim=1)).mean() else: codebook_kl_loss_mean = torch.tensor(0.0, device=self.device) - + codebook_loss = ((per_token_loss * policy_model_outputs['loss_mask']).sum(dim=1) / policy_model_outputs['loss_mask'].sum(dim=1)).mean() if total_loss is None: @@ -682,7 +682,7 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): total_loss += codebook_loss total_kl += codebook_kl_loss_mean - + total_loss /= self.cfg.num_audio_codebooks print("Total kl", total_kl, n_generations_per_item) return { @@ -700,21 +700,21 @@ def training_step(self, batch, batch_idx): self.log('train_kl_loss', po_outputs['kl_loss'], prog_bar=True, sync_dist=True) 
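The reward bookkeeping above normalizes each generation's reward against the other generations of the same item, using the group mean and standard deviation, before the per-token policy loss is scaled by the resulting advantages. A small numeric sketch of that group-relative normalization; the reward values and the epsilon guard are hypothetical, since the exact advantage formula falls outside this hunk:

import numpy as np

group_rewards = [0.92, 0.41, 0.73, 0.0]  # hypothetical rewards for 4 generations of one item
mean_reward = sum(group_rewards) / len(group_rewards)
std_reward = np.std(group_rewards)
advantages = [(r - mean_reward) / (std_reward + 1e-4) for r in group_rewards]
print(["%.3f" % a for a in advantages])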
self.log('train_mean_reward', po_outputs['mean_reward'], prog_bar=True, sync_dist=True) return po_outputs['loss'] - + def validation_step(self, batch, batch_idx): po_outputs = self.process_batch_online_po(batch, 1, mode='val') batch_metrics = po_outputs['batch_metrics'] mean_reward = po_outputs['mean_reward'] val_loss = po_outputs['loss'] val_kl_loss = po_outputs['kl_loss'] - + self.validation_step_outputs.append({ 'mean_reward': mean_reward, 'val_loss': val_loss, 'val_kl_loss': val_kl_loss, 'batch_metrics': batch_metrics, }) - + def on_validation_epoch_end(self): def collect(key): values = [] @@ -743,7 +743,7 @@ def collect(key): if key not in mean_metrics: mean_metrics[key] = [] mean_metrics[key].append(value) - + for key, values in mean_metrics.items(): mean_metrics[key] = np.mean(values) self.log(f"val_{key}", mean_metrics[key], prog_bar=True, sync_dist=True) @@ -760,7 +760,7 @@ def process_text_for_cer(input_text): """ # Convert text to lowercase lower_case_text = input_text.lower() - + # Remove commas from text no_comma_text = lower_case_text.replace(",", "") # Replace "-" with spaces @@ -768,7 +768,7 @@ def process_text_for_cer(input_text): no_dash_text = no_dash_text.replace("'", "") no_dash_text = no_dash_text.replace(";", "") no_dash_text = no_dash_text.replace(".", "") - + # Replace double spaces with single space single_space_text = " ".join(no_dash_text.split()) @@ -790,16 +790,16 @@ def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, audio_tensor = torch.tensor(audio, dtype=torch.float32, device=device) audio_batch.append(audio_tensor) audio_lengths.append(audio_tensor.size(0)) - + batch_audio_lens = torch.tensor(audio_lengths, device=device).long() max_audio_len = int(batch_audio_lens.max().item()) audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len]) _, speaker_embeddings = speaker_verification_model.forward( - input_signal=audio_batch, + input_signal=audio_batch, input_signal_length=batch_audio_lens ) - + return speaker_embeddings def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper_model, device): diff --git a/scripts/magpietts/codec_extraction.py b/scripts/magpietts/codec_extraction.py new file mode 100644 index 000000000000..1cfd90d798fa --- /dev/null +++ b/scripts/magpietts/codec_extraction.py @@ -0,0 +1,255 @@ +import json +import torch +from torch.utils.data import Dataset, DataLoader +import pytorch_lightning as pl +from pytorch_lightning import Trainer +from pytorch_lightning.strategies import DDPStrategy +from nemo.collections.tts.models import AudioCodecModel +import os +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +import argparse +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +class AudioDataset(Dataset): + def __init__(self, file_lists, base_audio_dirs, dataset_names, out_dir, sample_rate=22050, pad_multiple=1024): + self.file_list = file_lists + self.base_audio_dirs = base_audio_dirs + self.sample_rate = sample_rate + self.pad_multiple = pad_multiple + self.out_dir = out_dir + self.combined_file_list = [] + for fidx, file_list in enumerate(file_lists): + base_audio_dir = base_audio_dirs[fidx] + dataset_name = dataset_names[fidx] + for file_path in file_list: + audio_file_path = os.path.join(base_audio_dir, file_path) + self.combined_file_list.append({ + "file_path": file_path, + "audio_file_path": audio_file_path, + "dataset_name": dataset_name + }) + + def __len__(self): + return len(self.combined_file_list) + + def 
get_wav_from_filepath(self, file_path): + features = AudioSegment.segment_from_file( + file_path, target_sr=self.sample_rate, n_segments=-1, trim=False, + ) + audio_samples = features.samples + audio = torch.tensor(audio_samples) + audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) + audio_length = torch.tensor(audio.size(0)).long() + return audio, audio_length + + def __getitem__(self, idx): + file_path = self.combined_file_list[idx]["file_path"] + audio_file_path = self.combined_file_list[idx]["audio_file_path"] + dataset_name = self.combined_file_list[idx]["dataset_name"] + assert not file_path.startswith("/"), "file_path should be relative" + audio, audio_length = self.get_wav_from_filepath(audio_file_path) + codec_file_path_rel = file_path.replace(".wav", ".pt").replace(".flac", ".pt") + return { + "audio": audio, + "audio_length": audio_length, + "file_path": file_path, + "codec_file_path": os.path.join(self.out_dir, dataset_name, codec_file_path_rel) + } + + def collate_fn(self, batch): + audios_padded = [] + audio_lengths = [] + file_paths = [] + codec_file_paths = [] + max_audio_length = max(item["audio_length"].item() for item in batch) + for item in batch: + audio = torch.nn.functional.pad( + item["audio"], (0, max_audio_length - item["audio"].size(0)), value=0 + ) + audios_padded.append(audio) + audio_lengths.append(item["audio_length"]) + file_paths.append(item["file_path"]) + codec_file_paths.append(item["codec_file_path"]) + + return { + "audios": torch.stack(audios_padded), + "audio_lengths": torch.stack(audio_lengths), + "audio_file_paths": file_paths, + "codec_file_paths": codec_file_paths + } + + +class CodecExtractor(pl.LightningModule): + def __init__(self, model_path): + super().__init__() + self.codec_model = AudioCodecModel.restore_from(model_path, strict=False) + self.codec_model.eval() + + def forward(self, batch): + with torch.no_grad(): + codes, codes_lengths = self.codec_model.encode(audio=batch["audios"], audio_len=batch["audio_lengths"]) + return codes, codes_lengths + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + codes, codes_lengths = self(batch) + for i, file_path in enumerate(batch["codec_file_paths"]): + # get directory from file path + item_codes = codes[i, :, :codes_lengths[i]] # 8, T + torch.save(item_codes.cpu().type(torch.int16), file_path) + return None + +def read_manifest(manifest_path): + records = [] + with open(manifest_path, 'r') as f: + all_lines = f.readlines() + for line in all_lines: + line = line.strip() + record = json.loads(line) + records.append(record) + return records + +def write_manifest(manifest_path, records): + with open(manifest_path, 'w') as f: + file_str = "" + for record in records: + file_str += json.dumps(record) + "\n" + file_str = file_str.strip() + f.write(file_str) + print("Wrote {} records to: {}".format(len(records), manifest_path)) + +@rank_zero_only +def update_manifests(manifests, save_dir, dataset_names, codec_model_name): + for midx, manifest in enumerate(manifests): + records = read_manifest(manifest) + for ridx, record in enumerate(records): + audio_codes_path = record["audio_filepath"].replace(".wav", ".pt").replace(".flac", ".pt") + audio_codes_path = os.path.join(save_dir, dataset_names[midx], audio_codes_path) + record["target_audio_codes_path"] = audio_codes_path + if ridx % 10 == 0: + assert os.path.exists(audio_codes_path), "Audio codes not found: {}".format(audio_codes_path) + + if "context_audio_filepath" in record: + 
context_audio_codes_path = record["context_audio_filepath"].replace(".wav", ".pt").replace(".flac", ".pt") + context_audio_codes_path = os.path.join(save_dir, dataset_names[midx], context_audio_codes_path) + record["context_audio_codes_path"] = context_audio_codes_path + if ridx % 10 == 0: + assert os.path.exists(context_audio_codes_path), "Context audio codes not found: {}".format(context_audio_codes_path) + + write_manifest(manifest.replace(".json", "_withAudioCodes_{}.json".format(codec_model_name)), records) + +def prepare_directories(base_save_dir, codec_model_name, manifests, audio_base_dirs, dataset_names): + print("In prepare_directories") + save_dir = os.path.join(base_save_dir, codec_model_name) + file_lists = [] + for midx, manifest in enumerate(manifests): + records = read_manifest(manifest) + unique_audio_file_paths = {} + for record in records: + unique_audio_file_paths[record["audio_filepath"]] = 1 + if "context_audio_filepath" in record: + unique_audio_file_paths[record["context_audio_filepath"]] = 1 + file_list = list(unique_audio_file_paths.keys()) + file_lists.append(file_list) + for file_path in file_list: + dir_path = os.path.dirname(file_path) + out_dir_path = os.path.join(save_dir, dataset_names[midx], dir_path) + if not os.path.exists(out_dir_path): + os.makedirs(out_dir_path, exist_ok=True) + print("Created directories for saving audio codes at: ", save_dir, len(file_lists)) + return save_dir, file_lists + +if __name__ == "__main__": + """ + Usage: + python scripts/magpietts/codec_extraction.py \ + --manifests /home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json \ + --audio_base_dirs /datap/misc/Datasets/VCTK-Corpus \ + --codec_model_name codec21Khz_no_eliz \ + --dataset_names smallvctk \ + --save_dir /home/pneekhara/2023/SimpleT5NeMo/codec_outputs_21Khz \ + --codec_model_path /datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo \ + --sample_rate 22050 \ + --pad_multiple 1024 \ + --devices -1 \ + --num_nodes 1 + """ + parser = argparse.ArgumentParser() + parser.add_argument("--manifests", type=str) + parser.add_argument("--audio_base_dirs", type=str) + parser.add_argument("--dataset_names", type=str) + parser.add_argument("--save_dir", type=str) + parser.add_argument("--codec_model_path", type=str) + parser.add_argument("--codec_model_name", type=str) + parser.add_argument("--sample_rate", type=int) + parser.add_argument("--pad_multiple", type=int) + parser.add_argument("--devices", type=int, default=-1) + parser.add_argument("--num_nodes", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--num_workers", type=int, default=4) + args = parser.parse_args() + + trainer = Trainer( + devices=args.devices, + accelerator="gpu", + strategy=DDPStrategy(find_unused_parameters=False), + num_nodes=args.num_nodes, + log_every_n_steps=1, + max_epochs=1, + logger=False, + ) + + audio_base_dirs = args.audio_base_dirs.split(",") + dataset_names = args.dataset_names.split(",") + manifests = args.manifests.split(",") + + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + if rank == 0: + save_dir, file_lists = prepare_directories( + args.save_dir, + args.codec_model_name, + manifests, + audio_base_dirs, + dataset_names + ) + results = [save_dir, file_lists] + else: + results = [None, None] + torch.distributed.broadcast_object(results, src=0) + save_dir, file_lists = results + else: + save_dir, file_lists = 
prepare_directories( + args.save_dir, + args.codec_model_name, + manifests, + audio_base_dirs, + dataset_names + ) + + codec_extractor = CodecExtractor(args.codec_model_path) + + # Dataset and DataLoader + dataset = AudioDataset( + file_lists=file_lists, + base_audio_dirs=audio_base_dirs, + dataset_names=dataset_names, + out_dir=save_dir, + sample_rate=args.sample_rate, + pad_multiple=args.pad_multiple, + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + shuffle=False, + collate_fn=dataset.collate_fn, + ) + + # Run prediction (Saves the audio codes to files) + trainer.predict(codec_extractor, dataloaders=dataloader) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + update_manifests(manifests, save_dir, dataset_names, args.codec_model_name) \ No newline at end of file diff --git a/scripts/magpietts/dpo/create_text_contextpairs.py b/scripts/magpietts/dpo/create_text_contextpairs.py index 19f00835c043..2e7e1ee1184e 100644 --- a/scripts/magpietts/dpo/create_text_contextpairs.py +++ b/scripts/magpietts/dpo/create_text_contextpairs.py @@ -36,7 +36,7 @@ def main(): The resulting dataset is saved as a JSON manifest file. Example usage: - python scripts/t5tts/dpo/create_text_contextpairs.py \ + python scripts/magpietts/dpo/create_text_contextpairs.py \ --challenging_texts /Data/DPOPairsInputDatav2/challenging_with_short.txt \ --regular_texts_for_audiocontext /Data/DPOPairsInputDatav2/regular_texts_for_audiocontext.txt \ --regular_texts_for_textcontext /Data/DPOPairsInputDatav2/regular_texts_for_textcontext.txt \ diff --git a/scripts/magpietts/eval_squimmos.py b/scripts/magpietts/eval_squimmos.py new file mode 100644 index 000000000000..8c40994ed309 --- /dev/null +++ b/scripts/magpietts/eval_squimmos.py @@ -0,0 +1,67 @@ +from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE +import os +import json +import torch +import argparse +import librosa +import scipy.stats as stats +import numpy as np + +def find_sample_audios(audio_dir): + file_list = [] + for f in os.listdir(audio_dir): + if "predicted_audio" in f and f.endswith(".wav"): + audio_number = int(f.split("_")[-1].split(".wav")[0]) + file_list.append((audio_number, os.path.join(audio_dir, f))) + file_list.sort() + file_list = [t[1] for t in file_list] + return file_list + +def compute_mean_and_confidence_interval(measurements, confidence=0.95): + mean = np.mean(measurements) + std_err = stats.sem(measurements) + + confidence_interval = std_err * stats.t.ppf((1 + confidence) / 2, len(measurements) - 1) + + return "{:.4f} +/- {:.4f}".format(mean, confidence_interval), mean, confidence_interval + +def main(): + parser = argparse.ArgumentParser(description='Evaluate Squim MOS') + parser.add_argument('--exp_base_dir', type=str, default="/datap/misc/ContinuousEvalResults/NewTransformerKoelTTS") + parser.add_argument('--audio_dirs', type=str, default="svencoder_small_sp_ks3_onlyphoneme_epoch242_Temp0.6_Topk80_Cfg_False_1.0_libri_val") + args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + squim_mos_model = SQUIM_SUBJECTIVE.get_model().to(device) + + squim_score_list = [] + if args.audio_dirs == "all": + audio_dirs = [d for d in os.listdir(args.exp_base_dir) if os.path.isdir(os.path.join(args.exp_base_dir, d))] + else: + audio_dirs = args.audio_dirs.split(",") + out_file = os.path.join(args.exp_base_dir, "squim_mos_score.csv") + for audio_dir in audio_dirs: + print("Evaluating audio dir: ", audio_dir) + audio_dir_path = 
os.path.join(args.exp_base_dir, audio_dir, "audio") + audio_files = find_sample_audios(audio_dir_path) + for audio_file in audio_files: + pred_wav, sr = librosa.load(audio_file, sr=16000) + pred_wav = torch.tensor(pred_wav).to(device).unsqueeze(0) + + gt_path = audio_file.replace("predicted_audio", "target_audio") + gt_wav, sr = librosa.load(gt_path, sr=16000) + gt_wav = torch.tensor(gt_wav).to(device).unsqueeze(0) + with torch.no_grad(): + squm_mos_score = squim_mos_model(pred_wav, gt_wav) + squim_score_list.append(squm_mos_score.item()) + + mean_with_ci, mean, confidence_interval = compute_mean_and_confidence_interval(squim_score_list) + # Add to audio_dir,mean_with_ci to csv + with open(out_file, "a") as f: + f.write(audio_dir + "," + mean_with_ci + "\n") + print("Audio dir: ", audio_dir, "Mean with CI: ", mean_with_ci) + print("Wrote to file: ", out_file) + + +if __name__ == "__main__": + main() diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index b44c1601456a..77caaa788c00 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -1,8 +1,8 @@ dataset_meta_info = { 'vctk': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json', - 'audio_dir' : '/datap/misc/Datasets/VCTK-Corpus', - 'feature_dir' : '/datap/misc/Datasets/VCTK-Corpus', + 'manifest_path' : '/home/jasoli/data_prime/manifests/rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', + 'audio_dir' : '/home/jasoli/data_prime/RIVA-TTS/en/', + 'feature_dir' : '/home/jasoli/data_prime/codecs', }, 'riva_challenging': { 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', diff --git a/scripts/t5tts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py similarity index 94% rename from scripts/t5tts/evaluate_generated_audio.py rename to scripts/magpietts/evaluate_generated_audio.py index ecf75658a1d9..2d946d3a5dd6 100644 --- a/scripts/t5tts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -22,7 +22,7 @@ def find_sample_audios(audio_dir): file_list.sort() file_list = [t[1] for t in file_list] return file_list - + def read_manifest(manifest_path): records = [] with open(manifest_path, 'r') as f: @@ -35,18 +35,18 @@ def read_manifest(manifest_path): def process_text(input_text): # Convert text to lowercase lower_case_text = input_text.lower() - + # Remove commas from text no_comma_text = lower_case_text.replace(",", "") - + # Replace "-" with spaces no_dash_text = no_comma_text.replace("-", " ") - + # Replace double spaces with single space single_space_text = " ".join(no_dash_text.split()) single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) - + return single_space_text def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, language, device): @@ -66,7 +66,7 @@ def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, langua def extract_embedding(model, extractor, audio_path, device, sv_model_type): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) - + if sv_model_type == "wavlm": inputs = extractor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device) with torch.no_grad(): @@ -74,7 +74,7 @@ 
def extract_embedding(model, extractor, audio_path, device, sv_model_type): else: # Titanet with torch.no_grad(): embeddings = model.get_embedding(audio_path).squeeze() - + return embeddings.squeeze() def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large"): @@ -89,7 +89,7 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_large") elif asr_model_name == "nvidia/parakeet-ctc-0.6b": asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") - + # asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="nvidia/parakeet-tdt-1.1b") asr_model = asr_model.to(device) asr_model.eval() @@ -104,15 +104,15 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo speaker_verification_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv').to(device).eval() else: feature_extractor = None - speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') speaker_verification_model = speaker_verification_model.to(device) speaker_verification_model.eval() - speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_small') + speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_small') speaker_verification_model_alternate = speaker_verification_model_alternate.to(device) speaker_verification_model_alternate.eval() - + filewise_metrics = [] pred_texts = [] @@ -125,18 +125,14 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo gt_audio_filepath = os.path.join(audio_dir, gt_audio_filepath) if context_audio_filepath is not None: context_audio_filepath = os.path.join(audio_dir, context_audio_filepath) - + pred_audio_filepath = audio_file_lists[ridx] if language == "en": with torch.no_grad(): - if asr_model_name == "stt_en_conformer_transducer_large": - pred_text = asr_model.transcribe([pred_audio_filepath])[0][0] - gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0][0] - else: - pred_text = asr_model.transcribe([pred_audio_filepath])[0] - gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0] - + # import ipdb; ipdb.set_trace() + pred_text = asr_model.transcribe([pred_audio_filepath])[0].text pred_text = process_text(pred_text) + gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0].text gt_audio_text = process_text(gt_audio_text) else: pred_text = transcribe_with_whisper(whisper_model, whisper_processor, pred_audio_filepath, language, device) @@ -148,15 +144,15 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo gt_text = process_text(record['normalized_text']) else: gt_text = process_text(record['text']) - + detailed_cer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=True) detailed_wer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=False) - + print("{} GT Text:".format(ridx), gt_text) print("{} Pr Text:".format(ridx), pred_text) # Format cer and wer to 2 decimal places print("CER:", "{:.4f} | WER: {:.4f}".format(detailed_cer[0], detailed_wer[0])) - + 
pred_texts.append(pred_text) gt_texts.append(gt_text) gt_audio_texts.append(gt_audio_text) @@ -181,8 +177,8 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo pred_context_ssim_alternate = torch.nn.functional.cosine_similarity(pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() gt_context_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() - - + + filewise_metrics.append({ 'gt_text': gt_text, @@ -202,12 +198,12 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo 'pred_audio_filepath': pred_audio_filepath, 'context_audio_filepath': context_audio_filepath }) - + filewise_metrics_keys_to_save = ['cer', 'wer', 'pred_context_ssim', 'pred_text', 'gt_text', 'gt_audio_filepath', 'pred_audio_filepath', 'context_audio_filepath'] filtered_filewise_metrics = [] for m in filewise_metrics: filtered_filewise_metrics.append({k: m[k] for k in filewise_metrics_keys_to_save}) - + # Sort filewise metrics by cer in reverse filewise_metrics.sort(key=lambda x: x['cer'], reverse=True) @@ -244,10 +240,10 @@ def main(): assert args.evalset in dataset_meta_info args.manifest_path = dataset_meta_info[args.evalset]['manifest_path'] args.audio_dir = dataset_meta_info[args.evalset]['audio_dir'] - + evaluate(args.manifest_path, args.audio_dir, args.generated_audio_dir, args.whisper_language, sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b") - + if __name__ == "__main__": main() diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 2e0f7217e8c6..a565b73267a7 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -1,5 +1,5 @@ -from nemo.collections.tts.models import T5TTS_Model -from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset +from nemo.collections.tts.models import MagpieTTS_Model +from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset from omegaconf.omegaconf import OmegaConf, open_dict import os import glob @@ -29,18 +29,18 @@ def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0 return metrics def run_inference( - hparams_file, - checkpoint_file, - datasets, - out_dir, - temperature, - topk, - codecmodel_path, - use_cfg, - cfg_scale, - batch_size, - sv_model, - asr_model_name, + hparams_file, + checkpoint_file, + datasets, + out_dir, + temperature, + topk, + codecmodel_path, + use_cfg, + cfg_scale, + batch_size, + sv_model, + asr_model_name, num_repeats=1, apply_attention_prior=False, attention_prior_epsilon=1e-3, @@ -65,7 +65,7 @@ def run_inference( model_cfg.validation_ds = None - model = T5TTS_Model(cfg=model_cfg) + model = MagpieTTS_Model(cfg=model_cfg) model.use_kv_cache_for_inference = True # Load weights from checkpoint file @@ -79,13 +79,13 @@ def run_inference( checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}_sv_{}".format( - checkpoint_name, - temperature, - topk, - use_cfg, - cfg_scale, - apply_attention_prior, - attention_prior_epsilon, + checkpoint_name, + temperature, + topk, + use_cfg, + cfg_scale, + apply_attention_prior, + attention_prior_epsilon, attention_prior_lookahead_window, start_prior_after_n_audio_steps, "".join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else 
"None", @@ -114,7 +114,7 @@ def run_inference( if context_durration_min < 5.0 and context_durration_max > 5.0: context_durration_min = 5.0 context_durration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. - test_dataset = T5TTSDataset( + test_dataset = MagpieTTSDataset( dataset_meta=dataset_meta, sample_rate=model_cfg.sample_rate, min_duration=0.5, @@ -128,7 +128,7 @@ def run_inference( audio_eos_id=model.audio_eos_id, num_audio_codebooks=model_cfg.num_audio_codebooks, prior_scaling_factor=None, - load_cached_codes_if_available=dataset_meta_info[dataset].get('load_cached_codes_if_available', True), + load_cached_codes_if_available=False, dataset_type='test', tokenizer_config=None, load_16khz_audio=model.model_type == 'single_encoder_sv_tts', @@ -158,17 +158,17 @@ def run_inference( batch_cuda[key] = batch[key].cuda() else: batch_cuda[key] = batch[key] - + import time st = time.time() predicted_audio, predicted_audio_lens, _, _, rtf_metrics, cross_attention_maps, _ = model.infer_batch( - batch_cuda, - max_decoder_steps=440, - temperature=temperature, - topk=topk, - use_cfg=use_cfg, - cfg_scale=cfg_scale, - return_cross_attn_probs=True, + batch_cuda, + max_decoder_steps=440, + temperature=temperature, + topk=topk, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + return_cross_attn_probs=True, apply_attention_prior=apply_attention_prior, prior_epsilon=attention_prior_epsilon, lookahead_window_size=attention_prior_lookahead_window, @@ -199,11 +199,11 @@ def run_inference( if os.path.exists(target_audio_path): shutil.copy(target_audio_path, os.path.join(audio_dir, f"target_audio_{item_idx}.wav")) item_idx += 1 - + mean_rtf_metrics = {} for key in all_rtf_metrics[0]: mean_rtf_metrics[key] = float(np.mean([m[key] for m in all_rtf_metrics])) - + metrics, filewise_metrics = evaluate_generated_audio.evaluate( dataset_meta[dataset]['manifest_path'], dataset_meta[dataset]['audio_dir'], @@ -215,11 +215,11 @@ def run_inference( metrics_n_repeated.append(metrics) with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: json.dump(metrics, f, indent=4) - + with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: # Indent for better readability json.dump(filewise_metrics, f, indent=4) - + with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: json.dump(mean_rtf_metrics, f, indent=4) @@ -231,8 +231,8 @@ def run_inference( f.write(f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']}\n") print(f"Wrote metrics for {checkpoint_name} and {dataset} to {all_experiment_csv}") - metric_keys = ['cer_filewise_avg', 'wer_filewise_avg', 'cer_cumulative', 'wer_cumulative', - 'ssim_pred_gt_avg', 'ssim_pred_context_avg', 'ssim_gt_context_avg', + metric_keys = ['cer_filewise_avg', 'wer_filewise_avg', 'cer_cumulative', 'wer_cumulative', + 'ssim_pred_gt_avg', 'ssim_pred_context_avg', 'ssim_gt_context_avg', 'ssim_pred_gt_avg_alternate', 'ssim_pred_context_avg_alternate', 'ssim_gt_context_avg_alternate', 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative' ] @@ -291,7 +291,7 @@ def 
main(): assert len(hparam_files) == len(checkpoint_files), "Number of hparams files and checkpoint files should be the same." for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): run_inference( - hparams_file=hparams_file, + hparams_file=hparams_file, checkpoint_file=checkpoint_file, datasets=args.datasets.split(","), out_dir=args.out_dir, @@ -334,12 +334,12 @@ def main(): except: print(f"Skipping experiment {exp_name} as hparams or last checkpoint not found.") continue - last_checkpoint_path_draco = last_checkpoint.replace(BASE_EXP_DIR, DRACO_EXP_DIR) + last_checkpoint_path_draco = last_checkpoint.replace(BASE_EXP_DIR, DRACO_EXP_DIR) epoch_num = last_checkpoint.split("epoch=")[1].split("-")[0] checkpoint_copy_path = os.path.join(args.local_ckpt_dir, f"{exp_name}_epoch_{epoch_num}.ckpt") hparams_copy_path = os.path.join(args.local_ckpt_dir, f"{exp_name}_hparams.yaml") - + scp_command = f"scp {args.server_address}:{last_checkpoint_path_draco} {checkpoint_copy_path}" print(f"Running command: {scp_command}") os.system(scp_command) @@ -353,8 +353,8 @@ def main(): print("Hparams file path: ", hparams_copy_path) print("Checkpoint file path: ", checkpoint_copy_path) run_inference( - hparams_copy_path, - checkpoint_copy_path, + hparams_copy_path, + checkpoint_copy_path, datasets=args.datasets.split(","), out_dir=args.out_dir, temperature=args.temperature, @@ -375,7 +375,7 @@ def main(): confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer ) - + if __name__ == '__main__': main() \ No newline at end of file diff --git a/scripts/tts_dataset_to_lhotse/README.md b/scripts/tts_dataset_to_lhotse/README.md new file mode 100644 index 000000000000..436cbdbc4d20 --- /dev/null +++ b/scripts/tts_dataset_to_lhotse/README.md @@ -0,0 +1,82 @@ +# Everything Speech Data + +### Single turn speech to speech data + +Our single turn speech to speech data is in the form of conversations such that it will be easy to extend to multi-turn conversations. In this section we will go through the following: + +- Raw manifest format +- Lhotse cuts and shar format +- Creating Lhotse shar and cuts from raw manifest + +#### Raw manifest format + +Users need to get their manifests in the following format for `scripts/speech_data_generation/create_shar.py` to work. 
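For orientation, here is a minimal sketch (all paths, IDs, and text values are illustrative, and the `value` field is assumed to hold the audio filepath when `type` is `audio`) of how one single-turn datapoint could be emitted as one line of such a manifest; the authoritative field-by-field template follows below.

```python
import json

# Hypothetical single-turn record: one user (input) audio turn followed by
# one agent (target) audio turn.
record = {
    "sample_id": "example_0001",
    "conversations": [
        {
            "value": "audios/user/example_0001.wav",
            "from": "user",
            "type": "audio",
            "lang": "en",
            "instruction": "Transcribe and answer:",
        },
        {
            "value": "audios/agent/example_0001.wav",
            "from": "agent",
            "type": "audio",
            "lang": "en",
            "transcript": "NCT of Delhi",
        },
    ],
}

# The manifest is JSON Lines: one record serialized per line.
with open("manifest.json", "a") as f:
    f.write(json.dumps(record) + "\n")
```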
Each datapoint in the manifest should be in this format: + +``` +{ + 'sample_id': '', + 'normalized_answer_wer': (Optional, this is useful if speech data was synthesized and we want to filter based on wer of the synthesized speech), + 'normalized_answer_cer': (Optional, this is useful if speech data was synthesized and we want to filter based on cer of the synthesized speech), + 'conversations': [ + {'value': '', + 'from': 'user', + 'type': 'audio', + 'duration': (Optional | during shar creation it is automatically calculated), + 'lang': (Optional), + 'instruction': ''(This field is needed for s2s and in direct_s2s this is not needed) + }, + {'value': '', + 'from': 'agent', + 'type': 'audio', + 'duration': (Optional | during shar creation it is automatically calculated), + 'lang': (Optional), + 'transcript': '' + } + ] +} +``` + +#### Lhotse cuts and shar format + +There will be 3 types of files generated after you run `scripts/speech_data_generation/create_shar.py`: + +- cuts.{some_number}.jsonl.gz +- recording.{some_number}.tar +- target_audio.{some_number}.tar + +**recording.{some_number}.tar** - tarred user (input) speech wav files + +**target_audio.{some_number}.tar** - tarred agent (target) speech wav files + +**cuts.{some_number}.jsonl.gz** - You can think of this as the Lhotse manifest. The format or the fields are explained as below (This document will only go over the fields which are used during training/inference) + +This is what a typical cut would look like, which is one datapoint in any of the cuts.{some_number}.jsonl.gz files: +``` +MonoCut(id='squadv2_5705e3a452bb891400689658-2', start=0, duration=17.345306122448978, channel=0, supervisions=[SupervisionSegment(id='squadv2_5705e3a452bb891400689658-2', recording_id='squadv2_5705e3a452bb891400689658-2', start=0, duration=17.345306122448978, channel=0, text='Transcribe and answer:', language='EN', speaker='user', gender=None, custom=None, alignment=None), SupervisionSegment(id='squadv2_5705e3a452bb891400689658-2', recording_id='squadv2_5705e3a452bb891400689658-2', start=0, duration=1.1493877551020408, channel=0, text='NCT of Delhi', language='EN', speaker='agent', gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='squadv2_5705e3a452bb891400689658', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=44100, num_samples=764928, duration=17.345306122448978, channel_ids=[0], transforms=None), custom={'target_audio': Recording(id='_lustre_fsw_portfolios_llmservice_projects_llmservice_nemo_speechlm_data_speech_QA_outputs_speechall_squadv2_train_normalized___audios_squadv2_5705e3a452bb891400689658_synthesized_normalized_answer_audio', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=22050, num_samples=25344, duration=1.1493877551020408, channel_ids=[0], transforms=None), 'shard_origin': PosixPath('/lustre/fs7/portfolios/llmservice/projects/llmservice_nemo_speechlm/data/s2s_synthetic_data/s2s_lhotse_with_wavs/test2/cuts.000000.jsonl.gz'), 'shar_epoch': 0}) +``` + +Explaination of the fields: + +- id: Unique to the datapoint, this is used to reference recording wavs from the corresponding tarred files +- duration: This is the duration of the recording or source audio +- supervisions: Is a list of 2 elements (1 user turn and 1 agent turn) containing the metadata related to each turn. 
+- supervisions[0].text: Instruction from the user +- supervisions[0].speaker: user +- supervisions[0].language: language of input audio +- supervisions[1].text: transcript of target audio +- supervisions[1].speaker: agent +- supervisions[1].language: language of target audio +- custom['target_audio'] - This is the agent or target audio also in the form of a Recording. It has it's own duration, sampling_rate and id +- custom['target_audio'].id - is used to reference target_audios from target_audio tar file. +- custom['target_audio'].duration - self explainatory +- custom['target_audio'].sampling_rate - self explainatory + +#### Creating Lhotse shar and cuts from raw manifest + +To create Lhotse shar and cuts from raw manifests simple run the following command: +``` +python scripts/speech_to_speech_data_generation/create_shars.py \ +--manifest= \ +--out_shar_dir= \ +--num_shard= +``` \ No newline at end of file diff --git a/scripts/tts_dataset_to_lhotse/create_shars.py b/scripts/tts_dataset_to_lhotse/create_shars.py new file mode 100644 index 000000000000..ad7663e86fc7 --- /dev/null +++ b/scripts/tts_dataset_to_lhotse/create_shars.py @@ -0,0 +1,167 @@ +from pathlib import Path +import json +import os +import argparse +import shutil +import csv +import soundfile as sf +### from nemo.collections.tts.models import AudioCodecModel +import librosa +import torch +import numpy as np +from tqdm import tqdm +from matplotlib import pyplot as plt + +from lhotse import CutSet, SupervisionSegment, Recording, AudioSource +from lhotse.cut.base import Cut +from lhotse.features.base import Features, FeatureSet +from lhotse.array import TemporalArray, Array +from lhotse.shar.writers import AudioTarWriter +from lhotse.audio import RecordingSet + +def json_reader(filename): + with open(filename) as f: + for line in f: + yield json.loads(line) + + +def create_shar_from_manifest( + manifest, audio_root_path, out_shar_dir, shard_size=1000 +): + in_manifest = list(json_reader(manifest)) + print(f"...loaded {manifest} # of datapoints {len(in_manifest)}") + num_shard = int(len(in_manifest) / shard_size) + if len(in_manifest) % shard_size != 0: + shard_size += 1 + print(f"shard_size {shard_size} num_shards {num_shard}") + + user_recordings = [] + answer_list = [] + instructions = [] + source_language = [] + target_language = [] + target_recordings = [] + for i, line in tqdm(enumerate(in_manifest)): + # For single turn convs is a list of 2 elements + # First element is user speech and second is agent speech + + # User_Speech + context_audio_path = line["context_audio_filepath"] + + user_recording = Recording.from_file(os.path.join(audio_root_path, context_audio_path)) + user_recordings.append(user_recording) + + # This are the context text, this could be different things like a simple instruction or details about speaker voice + instructions.append(" ") + + # Language source + if "lang" in line: + language = line["lang"] + elif "language" in line: + language = line["language"] + elif "Language:" in str(line["speaker"]): + language = line["speaker"].split("Language:")[1].split(" ")[0] + else: + language = "en" + + source_language.append(language) + + # Loading agent audio and using only the extracted features as nd.array + target_recordings.append(Recording.from_file(os.path.join(audio_root_path, line["audio_filepath"]))) + # Agent answer transcript + answer_list.append(line["text"]) + # Language target + target_language.append(language) + + + print("Done extracting data from manifest") + print(len(user_recordings)) + 
cuts = CutSet.from_manifests(recordings=RecordingSet.from_recordings(user_recordings)) + + # Attach text + for i, cut in tqdm(enumerate(cuts)): + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=0, + duration=cut.recording.duration, + text=instructions[i], + speaker="user", + language=source_language[i].upper(), + ), + ) + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=0, + duration=target_recordings[i].duration, + text=answer_list[i], + speaker="agent", + language=target_language[i].upper(), + ), + ) + cut.target_audio = target_recordings[i] + + print("...Making Shars") + out_shar_dir = Path(out_shar_dir) + out_shar_dir.mkdir(parents=True, exist_ok=True) + shard_size = shard_size + assert len(user_recordings) % shard_size != 0, "Lhotse breaks if feat_list is a multiple of shard_size" + exported = cuts.to_shar( + out_shar_dir, fields={"recording": "wav"}, num_jobs=4, shard_size=shard_size + ) + print(f"...share created") + + print(f"...Exporting target_audio to tar files") + for i, path in tqdm(enumerate(exported["cuts"])): + path = path[0] + out_path = path.replace("cuts", "target_audio").replace(".jsonl.gz", ".tar") + with AudioTarWriter( + out_path, shard_size=None, format="flac" + ) as writer: + for cut in CutSet.from_file(path): + writer.write(cut.id, cut.target_audio.load_audio(), manifest=cut.target_audio, sampling_rate=22050) + print(f"...Exported target_audio to tar files") + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--manifest', + type=str, + default="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/manifests/hifitts__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json", + ) + parser.add_argument( + '--audio_root_path', + type=str, + default="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/hi_fi_tts_v0/", + ) + parser.add_argument( + '--out_shar_dir', + type=str, + default="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/hifitts/", + ) + parser.add_argument( + '--shard_size', + type=int, + default=1000, + ) + + args = parser.parse_args() + print(f"manifest {args.manifest}") + print(f"audio_root_path {args.audio_root_path}") + print(f"out_shar_dir {args.out_shar_dir}") + print(f"num_shard {args.shard_size}") + + create_shar_from_manifest( + manifest=args.manifest, + audio_root_path=args.audio_root_path, + out_shar_dir=args.out_shar_dir, + shard_size=args.shard_size, + ) + +if __name__ == "__main__": + main() + + diff --git a/t5tts_inference_multiturndialogues.ipynb b/t5tts_inference_multiturndialogues.ipynb new file mode 100644 index 000000000000..20f1fa7ff645 --- /dev/null +++ b/t5tts_inference_multiturndialogues.ipynb @@ -0,0 +1,320 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "466ccdc5", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.tts.models import T5TTS_Model\n", + "from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample\n", + "from omegaconf.omegaconf import OmegaConf, open_dict\n", + "import torch\n", + "import os\n", + "import soundfile as sf\n", + "from IPython.display import display, Audio\n", + "import os\n", + "import numpy as np\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "6659ae78", + "metadata": {}, + "source": [ + "## Set checkpoint and other file paths on your machine" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "04445f11", + "metadata": {}, + "outputs": [], + "source": [ + "hparams_file = \"/datap/misc/duplexcheckpoints/blackwell/duplex_blackwell_medium_decoder_noexpresso_onlyphonemeFT_hparams.yaml\"\n", + "checkpoint_file = \"/datap/misc/duplexcheckpoints/blackwell/duplex_blackwell_medium_decoder_withTC_fromroycheckpoint_lowsestvalloss.ckpt\"\n", + "codecmodel_path = \"/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo\"\n", + "out_dir = \"/datap/misc/t5tts_inference_notebook_samples\"\n", + "if not os.path.exists(out_dir):\n", + " os.makedirs(out_dir)\n", + "\n", + "dummy_audio_filepath = os.path.join(out_dir, \"dummy_audio.wav\")\n", + "sf.write(dummy_audio_filepath, np.zeros(22050 * 3), 22050)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "01e24b14", + "metadata": {}, + "source": [ + "## Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87bf66f9", + "metadata": {}, + "outputs": [], + "source": [ + "model_cfg = OmegaConf.load(hparams_file).cfg\n", + "\n", + "with open_dict(model_cfg):\n", + " model_cfg.codecmodel_path = codecmodel_path\n", + " if hasattr(model_cfg, 'text_tokenizer'):\n", + " # Backward compatibility for models trained with absolute paths in text_tokenizer\n", + " model_cfg.text_tokenizer.g2p.phoneme_dict = \"scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt\"\n", + " model_cfg.text_tokenizer.g2p.heteronyms = \"scripts/tts_dataset_files/heteronyms-052722\"\n", + " model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0\n", + " model_cfg.train_ds = None\n", + " model_cfg.validation_ds = None\n", + "\n", + "\n", + "model = T5TTS_Model(cfg=model_cfg)\n", + "# Load weights from checkpoint file\n", + "print(\"Loading weights from checkpoint\")\n", + "ckpt = torch.load(checkpoint_file)\n", + "model.load_state_dict(ckpt['state_dict'])\n", + "print(\"Loaded weights.\")\n", + "\n", + "model.use_kv_cache_for_inference = True\n", + "\n", + "model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "843167ec", + "metadata": {}, + "outputs": [], + "source": [ + "test_dataset = T5TTSDataset(\n", + " dataset_meta={},\n", + " sample_rate=model_cfg.sample_rate,\n", + " min_duration=0.5,\n", + " max_duration=20,\n", + " codec_model_downsample_factor=model_cfg.codec_model_downsample_factor,\n", + " bos_id=model.bos_id,\n", + " eos_id=model.eos_id,\n", + " context_audio_bos_id=model.context_audio_bos_id,\n", + " context_audio_eos_id=model.context_audio_eos_id,\n", + " audio_bos_id=model.audio_bos_id,\n", + " audio_eos_id=model.audio_eos_id,\n", + " num_audio_codebooks=model_cfg.num_audio_codebooks,\n", + " prior_scaling_factor=None,\n", + " load_cached_codes_if_available=True,\n", + " dataset_type='test',\n", + " tokenizer_config=None,\n", + " load_16khz_audio=model.model_type == 'single_encoder_sv_tts',\n", + " use_text_conditioning_tokenizer=model.use_text_conditioning_encoder,\n", + " pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,\n", + " context_duration_min=model.cfg.get('context_duration_min', 5.0),\n", + " context_duration_max=model.cfg.get('context_duration_max', 5.0),\n", + ")\n", + "test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "cbd78348", + "metadata": {}, + "source": [ + "### Set dialogues to generate\n", + "\n", + "Each item in the list can be a 
single-turn or a multi-turn dialogue.\n", + "\n", + "[SPK-BWL-B-F] is the speaker tag for female speaker and [SPK-BWL-B-M] is the speaker tag for male speaker.\n", + "\n", + "ChatGPT prompt that can generate something like this:\n", + "\n", + "```\n", + "Generate dialogues for a 30 second podcast about .\n", + "The conversation should be between a male and female speaker formatted as follows:\n", + "[SPK-BWL-B-F] Sentence by a female speaker\n", + "[SPK-BWL-B-M] Sentence by a male speaker\n", + "where [SPK-BWL-B-F] and [SPK-BWL-B-M] indicate speaker tags. Dont have any quotation marks in the text, and if there are any numbers spell them out. Basically, keep the text normalized suitable for a TTS model. Keep the conversation fun and engaging with the speakers talking and responding to each other. \n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5e32a27", + "metadata": {}, + "outputs": [], + "source": [ + "dialogues = [\"[SPK-BWL-B-F] Have you been keeping up with multi-turn TTS models? They are so fascinating [SPK-BWL-B-M] Absolutely. The way they handle conversations feels almost human.\",\n", + "\"[SPK-BWL-B-F] I know. It is amazing how they can remember context across multiple exchanges [SPK-BWL-B-M] Exactly. It is not just about saying words anymore. They actually make the responses flow naturally.\",\n", + "\"[SPK-BWL-B-F] And it is not just for assistants or chatbots. I heard they are being used for things like audiobooks and interactive storytelling.\",\n", + "\"[SPK-BWL-B-M] That is true. Plus, they can even adjust tone to match emotions. Imagine hearing a story where the narrator sounds genuinely excited during a twist.\",\n", + "\"[SPK-BWL-B-F] Or even a bit sarcastic during a funny moment. That makes everything feel more real.\",\n", + "\"[SPK-BWL-B-M] For sure. It is like conversations with a machine are finally catching up to the way we actually talk.\",\n", + "\"[SPK-BWL-B-F] The future of tech keeps surprising me. Every day, there is something new to explore.\",]\n", + "dialogues = [d for d in dialogues if len(d) > 0]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "aa182540", + "metadata": {}, + "source": [ + "### Generation\n", + "\n", + "Below code generates 4 samples for each item in the dialogues list. Then asks, which one you like the best (index 0,1,2 or 3) and adds that to the already generated dialogue. You may modify the below code to automate and just select the first generation if you dont want to do this manually. After every dialogue it also plays the combined dialogues until that item." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "168f157d", + "metadata": {}, + "outputs": [], + "source": [ + "from pydub import AudioSegment\n", + "\n", + "def play_combined_audio(audio_files):\n", + " combined_audio = AudioSegment.empty()\n", + " for file in audio_files:\n", + " audio = AudioSegment.from_file(file)\n", + " combined_audio += audio\n", + " output_path = os.path.join(out_dir, \"combined_audio.wav\")\n", + " combined_audio.export(output_path, format=\"wav\")\n", + " display(Audio(output_path))\n", + " return output_path\n", + " \n", + "\n", + "\n", + "context_path = None\n", + "generated_audios = []\n", + "for didx, dialogue in enumerate(dialogues):\n", + " audio_dir = \"/\"\n", + " entry = {\n", + " \"audio_filepath\": dummy_audio_filepath,\n", + " \"duration\": 3.0,\n", + " \"text\": dialogue,\n", + " \"speaker\": \"dummy\",\n", + " }\n", + " if didx == 0:\n", + " entry[\"context_text\"] = \"MIXED SPEECH TTS\"\n", + " else:\n", + " entry['context_audio_filepath'] = context_path\n", + " entry['context_audio_duration'] = 5.0\n", + " \n", + " \n", + " \n", + " data_sample = DatasetSample(\n", + " dataset_name=\"sample\",\n", + " manifest_entry=entry,\n", + " audio_dir=audio_dir,\n", + " feature_dir=audio_dir,\n", + " text=entry['text'],\n", + " speaker=None,\n", + " speaker_index=0,\n", + " tokenizer_names=[\"english_phoneme\"]\n", + " )\n", + " test_dataset.data_samples = [data_sample]\n", + "\n", + " test_data_loader = torch.utils.data.DataLoader(\n", + " test_dataset,\n", + " batch_size=1,\n", + " collate_fn=test_dataset.collate_fn,\n", + " num_workers=0,\n", + " shuffle=False\n", + " )\n", + " \n", + " \n", + " item_idx = 0\n", + " for bidx, batch in enumerate(test_data_loader):\n", + " print(\"Processing batch {} out of {}\".format(bidx, len(test_data_loader)))\n", + " model.t5_decoder.reset_cache(use_cache=True)\n", + " batch_cuda ={}\n", + " for key in batch:\n", + " if isinstance(batch[key], torch.Tensor):\n", + " batch_cuda[key] = batch[key].cuda()\n", + " else:\n", + " batch_cuda[key] = batch[key]\n", + " import time\n", + " \n", + " candidates = []\n", + " for try_idx in range(4):\n", + " st = time.time()\n", + " predicted_audio, predicted_audio_lens, _, _ = model.infer_batch(\n", + " batch_cuda, \n", + " max_decoder_steps=500, \n", + " temperature=0.6, \n", + " topk=80, \n", + " use_cfg=True, \n", + " cfg_scale=1.6\n", + " )\n", + " print(\"generation time\", time.time() - st)\n", + " for idx in range(predicted_audio.size(0)):\n", + " predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()\n", + " predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]\n", + " audio_path = os.path.join(out_dir, f\"predicted_audio_{try_idx}_{didx}_{item_idx}.wav\")\n", + " sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)\n", + " print(\"Dialogue:\", item_idx, \"Candidate: \", try_idx)\n", + " display(Audio(audio_path))\n", + " candidates.append(audio_path)\n", + " \n", + " user_input = input(\"Enter Candidate number that sounds the best:\").strip().lower()\n", + " selected_audio_idx = int(user_input)\n", + " audio_path = candidates[selected_audio_idx]\n", + " item_idx += 1\n", + " generated_audios.append(audio_path)\n", + " print(\"Podcast generated until now:\")\n", + " combined_audio_path = play_combined_audio(generated_audios)\n", + " last_gen_audio = AudioSegment.from_file(combined_audio_path)\n", + " last_5_seconds = last_gen_audio[-5000:] # Duration is in milliseconds\n", + " context_path = 
os.path.join(out_dir, \"context.wav\")\n", + " last_5_seconds.export(context_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1368c380", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "943a791f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 3791bcdf6e301d8ae0656d6c26670a22436e9489 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 25 Mar 2025 13:37:07 -0700 Subject: [PATCH 009/113] [bugfix][magpietts] replace pytorch_lightning with lightning.pytorch to make the recipe working with PTL 1.9+. (#47) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- examples/tts/magpietts.py | 2 +- nemo/collections/tts/models/magpietts.py | 4 +--- .../tts/models/magpietts_preference_optimization.py | 4 ++-- scripts/magpietts/codec_extraction.py | 10 +++++----- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index 7f745ced2b6e..2521ac9a2248 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf, open_dict from nemo.collections.tts.models import MagpieTTS_Model, MagpieTTS_ModelInference, MagpieTTS_Model_PrefDataGen, MagpieTTS_Model_OfflinePO, MagpieTTS_Model_OnlinePO diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 508734b62dba..466fbbef5e6d 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import json import os import random @@ -25,7 +24,7 @@ import torch from hydra.utils import instantiate from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from pytorch_lightning.loggers import TensorBoardLogger from torch import nn from torch.utils.data import get_worker_info @@ -35,7 +34,6 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer -from nemo.collections.tts.data.text_to_speech_dataset_lhotse import build_lhotse_dataloader, MagpieTTSLhotseDataset from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index f6e9c56eebee..aad60f922360 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -1,7 +1,7 @@ import numpy as np import torch from omegaconf import DictConfig -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from torch import nn import os import json @@ -812,4 +812,4 @@ def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) result = transcription[0] - return result \ No newline at end of file + return result diff --git a/scripts/magpietts/codec_extraction.py b/scripts/magpietts/codec_extraction.py index 1cfd90d798fa..a4a827f76089 100644 --- a/scripts/magpietts/codec_extraction.py +++ b/scripts/magpietts/codec_extraction.py @@ -1,14 +1,14 @@ import json import torch from torch.utils.data import Dataset, DataLoader -import pytorch_lightning as pl -from pytorch_lightning import Trainer -from pytorch_lightning.strategies import DDPStrategy +import lightning.pytorch as pl +from lightning.pytorch import Trainer +from lightning.pytorch.strategies import DDPStrategy from nemo.collections.tts.models import AudioCodecModel import os from nemo.collections.asr.parts.preprocessing.segment import AudioSegment import argparse -from pytorch_lightning.utilities.rank_zero import rank_zero_only +from lightning.pytorch.utilities import rank_zero_only class AudioDataset(Dataset): def __init__(self, file_lists, base_audio_dirs, dataset_names, out_dir, sample_rate=22050, pad_multiple=1024): @@ -252,4 +252,4 @@ def prepare_directories(base_save_dir, codec_model_name, manifests, audio_base_d if torch.distributed.is_initialized(): torch.distributed.barrier() - update_manifests(manifests, save_dir, dataset_names, args.codec_model_name) \ No newline at end of file + update_manifests(manifests, save_dir, dataset_names, args.codec_model_name) From 6ebde2597488806a240025f346156bec347bfafd Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 25 Mar 2025 13:37:16 -0700 Subject: [PATCH 010/113] [magpietts] enable pin_memory to enable fast data transfer to GPU. 
(#48) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- examples/tts/conf/magpietts/magpietts_en.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index c46ab208dfe9..6850027be04f 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -92,6 +92,7 @@ model: batch_size: ${batch_size} num_workers: 4 drop_last: true + pin_memory: true validation_ds: dataset: @@ -105,6 +106,7 @@ model: dataloader_params: batch_size: ${batch_size} num_workers: 0 + pin_memory: true encoder: n_layers: 6 From bc1bcfbb92749b488b7497e7d267f13f2c7f435e Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 25 Mar 2025 13:37:23 -0700 Subject: [PATCH 011/113] [bugfix][tts_dataset] feature_dir is not a required key for magpietts. modify and make it optional. (#49) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- nemo/collections/tts/data/text_to_speech_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 83b1145673b0..ba39b2230410 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -196,7 +196,7 @@ def _preprocess_manifest( dataset_name=dataset_name, manifest_entry=entry, audio_dir=Path(dataset.audio_dir), - feature_dir=Path(dataset.feature_dir), + feature_dir=Path(dataset.feature_dir) if dataset.feature_dir is not None else None, text=text, speaker=speaker, speaker_index=speaker_index, From ed60df4949604e2ed466759a7a873b95e835b084 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Sat, 29 Mar 2025 09:32:20 -0700 Subject: [PATCH 012/113] [magpietts] added WandbLogger support and restructure TensorBoardLogger. (#52) * structured both loggers for train/val/test. * enable `resume` param to ensure the resumed training logs being merged on the previous run id. * removed `tb_logger` func. 
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- examples/tts/conf/magpietts/magpietts_en.yaml | 1 + nemo/collections/tts/models/magpietts.py | 158 +++++++++++------- 2 files changed, 99 insertions(+), 60 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index 6850027be04f..cc0694f66283 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -186,6 +186,7 @@ exp_manager: wandb_logger_kwargs: name: null project: null + resume: true create_checkpoint_callback: true checkpoint_callback_params: monitor: val_loss diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 466fbbef5e6d..189b1f25c7e3 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -23,14 +23,13 @@ import soundfile as sf import torch from hydra.utils import instantiate -from omegaconf import DictConfig, open_dict from lightning.pytorch import Trainer -from pytorch_lightning.loggers import TensorBoardLogger +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from omegaconf import DictConfig, open_dict from torch import nn from torch.utils.data import get_worker_info from transformers import AutoTokenizer, T5Tokenizer - import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer @@ -38,7 +37,11 @@ from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.modules.aligner import AlignmentEncoder -from nemo.collections.tts.parts.utils.helpers import binarize_attention_parallel, get_mask_from_lengths, plot_alignment_to_numpy +from nemo.collections.tts.parts.utils.helpers import ( + binarize_attention_parallel, + get_mask_from_lengths, + plot_alignment_to_numpy, +) from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo @@ -153,8 +156,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): ) # Changing these to make them different from target audio bos and eos self.context_audio_eos_id = cfg.num_audio_tokens_per_codebook - 3 - self._tb_logger = None - self.pad_context_text_to_max_duration = self.model_type == 'decoder_context_tts' self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) @@ -277,19 +278,6 @@ def _setup_tokenizers(self, cfg, mode='test'): ) return tokenizer, text_conditioning_tokenizer - @property - def tb_logger(self): - if self._tb_logger is None: - if self.logger is None and self.logger.experiment is None: - return None - tb_logger = self.logger.experiment - for logger in self.trainer.loggers: - if isinstance(logger, TensorBoardLogger): - tb_logger = logger.experiment - break - self._tb_logger = tb_logger - return self._tb_logger - def audio_to_codes(self, audio, audio_len, audio_type='target'): # audio: (B, T) # audio_len: (B,) @@ -521,20 +509,34 @@ def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens with torch.no_grad(): attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1) # (B, C, audio_timesteps, text_timesteps) attention_prob_matrix_mean = 
attention_prob_matrix.mean(dim=1) # (B, audio_timesteps, text_timesteps) + + images = list() for idx in range(min(3, attention_prob_matrix_mean.size(0))): item_attn_matrix = attention_prob_matrix_mean[idx][ dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx] ] item_attn_matrix = item_attn_matrix.detach().cpu().numpy() - attn_np = plot_alignment_to_numpy(item_attn_matrix.T) - self.tb_logger.add_image( - f'{prefix}attention_matrix_{idx}', - attn_np, - global_step=self.global_step, - dataformats="HWC", + images.append(plot_alignment_to_numpy(item_attn_matrix.T)) + + if isinstance(self.logger, WandbLogger) and HAVE_WANDB: + self.logger.log_image( + key=f"Image/{prefix}/attention_matrix", + images=images, + step=self.global_step, + caption=[f"Example_{idx}" for idx in range(len(images))], ) + elif isinstance(self.logger, TensorBoardLogger): + for idx, img in enumerate(images): + self.logger.experiment.add_image( + f'{prefix}/attention_matrix/Example_{idx}', + img, + global_step=self.global_step, + dataformats="HWC", + ) + else: + ValueError(f"Invalid logger: {self.logger}") - def log_train_val_example( + def log_train_val_audio_example( self, logits, target_audio_codes, @@ -545,36 +547,60 @@ def log_train_val_example( pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target) pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target) target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target) + context_audio, context_audio_lens = None, None if context_audio_codes is not None and context_audio_codes.shape[2] > 3: # > 3 ensures, it is a valid context audio tensor (and not dummy tensor used in text context) context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens) + for idx in range(min(3, pred_audio.size(0))): pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() target_audio_np = target_audio[idx].float().detach().cpu().numpy() pred_audio_np = pred_audio_np[: pred_audio_lens[idx]] target_audio_np = target_audio_np[: target_audio_lens[idx]] - self.tb_logger.add_audio( - f'pred_audio_{idx}', - pred_audio_np, - global_step=self.global_step, - sample_rate=self.cfg.sample_rate, - ) - self.tb_logger.add_audio( - f'target_audio_{idx}', - target_audio_np, - global_step=self.global_step, - sample_rate=self.cfg.sample_rate, - ) + context_audio_np = None if context_audio is not None: context_audio_np = context_audio[idx].float().detach().cpu().numpy() context_audio_np = context_audio_np[: context_audio_lens[idx]] - self.tb_logger.add_audio( - f'context_audio_{idx}', - context_audio_np, + + if isinstance(self.logger, WandbLogger) and HAVE_WANDB: + if context_audio_np is not None: + audios_np = [context_audio_np] + captions = ["context"] + else: + audios_np = list() + captions = list() + audios_np = audios_np + [pred_audio_np, target_audio_np] + captions = captions + ["prediction", "target"] + self.logger.log_audio( + key=f"Audio/Example_{idx}", + audios=audios_np, + step=self.global_step, + sample_rate=[self.cfg.sample_rate] * len(audios_np), + caption=captions, + ) + elif isinstance(self.logger, TensorBoardLogger): + if context_audio_np is not None: + self.logger.experiment.add_audio( + f'Example_{idx}/context', + context_audio_np, + global_step=self.global_step, + sample_rate=self.cfg.sample_rate, + ) + self.logger.experiment.add_audio( + f'Example_{idx}/prediction', + pred_audio_np, + global_step=self.global_step, + 
sample_rate=self.cfg.sample_rate, + ) + self.logger.experiment.add_audio( + f'Example_{idx}/target', + target_audio_np, global_step=self.global_step, sample_rate=self.cfg.sample_rate, ) + else: + ValueError(f"Invalid logger: {self.logger}") def scale_prior(self, prior, global_step): if prior is None: @@ -1007,17 +1033,17 @@ def training_step(self, batch, batch_idx): batch_output = self.process_batch(batch) loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] - self.log('train_codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) + self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) if self.cfg.get('cfg_unconditional_prob', 0.0) == 0.0: # Only log alignment loss when not using cfg to avoid sync issues when # alignment loss is None on some ranks alignment_loss = batch_output['alignment_loss'] if alignment_loss is not None: - self.log('train_alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) - self.log('train_loss', loss, prog_bar=True, sync_dist=True) + self.log('train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) + self.log('train/loss', loss, prog_bar=True, sync_dist=True) local_transformer_loss = batch_output['local_transformer_loss'] if local_transformer_loss is not None: - self.log('train_local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) + self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) return loss @@ -1041,7 +1067,7 @@ def validation_step(self, batch, batch_idx): aligner_encoder_loss = torch.tensor(0.0, device=loss.device) if batch_idx == 0 and self.global_rank == 0: - self.log_train_val_example( + self.log_train_val_audio_example( logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens ) if ( @@ -1055,19 +1081,19 @@ def validation_step(self, batch, batch_idx): cross_attention_probs, audio_codes_lens_target, text_lens, - prefix="val_", + prefix="val", dec_context_size=dec_context_size, ) for layer_idx in self.transcript_decoder_layers: cross_attention_probs = [ attn_info[layer_idx]['cross_attn_probabilities'][0] ] - self.log_attention_probs(cross_attention_probs, audio_codes_lens_target, text_lens, prefix=f"val_layer_{layer_idx}_", dec_context_size=dec_context_size) + self.log_attention_probs(cross_attention_probs, audio_codes_lens_target, text_lens, prefix=f"val/layer_{layer_idx}", dec_context_size=dec_context_size) if batch_output['aligner_attn_soft'] is not None: self.log_attention_probs( [batch_output['aligner_attn_soft']], audio_codes_lens_target, text_lens, - prefix=f"val_aligner_encoder_attn_", + prefix=f"val/aligner_encoder_attn", ) if batch_output['aligner_attn_hard'] is not None: @@ -1075,7 +1101,7 @@ def validation_step(self, batch, batch_idx): [batch_output['aligner_attn_hard'].unsqueeze(1)], audio_codes_lens_target, text_lens, - prefix=f"val_aligner_encoder_attn_hard_", + prefix=f"val/aligner_encoder_attn_hard", ) local_transformer_loss = batch_output['local_transformer_loss'] @@ -1447,12 +1473,24 @@ def test_step(self, batch, batch_idx): predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] item_idx = batch_idx * test_dl_batch_size + idx - self.tb_logger.add_audio( - 'predicted_audio', - predicted_audio_np, - global_step=item_idx, - sample_rate=self.cfg.sample_rate, - ) + + if isinstance(self.logger, WandbLogger) and HAVE_WANDB: + log_dict = { + f"test/predicted_audio": wandb.Audio( 
+ predicted_audio_np, sample_rate=self.cfg.sample_rate, caption=f"Predicted Audio" + ), + } + self.logger.experiment.log(log_dict, step=item_idx) + elif isinstance(self.logger, TensorBoardLogger): + self.logger.experiment.add_audio( + 'test/predicted_audio', + predicted_audio_np, + global_step=item_idx, + sample_rate=self.cfg.sample_rate, + ) + else: + ValueError(f"Invalid logger: {self.logger}") + # Save the predicted audio log_dir = self.logger.log_dir audio_dir = os.path.join(log_dir, 'audios') @@ -1468,12 +1506,12 @@ def on_validation_epoch_end(self): val_alignment_loss = collect("val_alignment_loss") val_aligner_encoder_loss = collect("val_aligner_encoder_loss") self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) - self.log("val_codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) - self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) - self.log("val_aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) + self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) + self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) + self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) if self.cfg.get('use_local_transformer', False): val_local_transformer_loss = collect("val_local_transformer_loss") - self.log("val_local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) + self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() # free memory def get_dataset(self, cfg, dataset_type): From 38ff49a294bf7cbeeec075b40869eab34b37c5c0 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 1 Apr 2025 06:40:31 -0700 Subject: [PATCH 013/113] [magpietts] adhere CapWords convention for model class names. (#50) * [magpietts] minor fix for the usage of freezing a model. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * fixed a typo. 
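The W&B branch added to test_step above reduces to a plain `experiment.log` call carrying a `wandb.Audio` payload. A minimal, self-contained sketch of that call is shown below; the silent one-second array and offline mode are stand-ins so the snippet runs without a trained model or a W&B account.

import numpy as np
import wandb

run = wandb.init(project="magpietts", mode="offline")  # offline keeps the sketch free of network side effects
predicted_audio_np = np.zeros(22050, dtype=np.float32)  # placeholder for model output at 22.05 kHz
run.log(
    {"test/predicted_audio": wandb.Audio(predicted_audio_np, sample_rate=22050, caption="Predicted Audio")},
    step=0,
)
run.finish()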
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Apply suggestions from code review --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Jason --- examples/tts/magpietts.py | 18 +++++--- .../tts/data/text_to_speech_dataset.py | 4 +- nemo/collections/tts/models/__init__.py | 18 ++++---- nemo/collections/tts/models/magpietts.py | 37 ++++++----------- .../magpietts_preference_optimization.py | 41 ++++++++++--------- scripts/magpietts/infer_and_evaluate.py | 29 +++++++------ 6 files changed, 74 insertions(+), 73 deletions(-) diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index 2521ac9a2248..a89761207e2b 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -15,7 +15,13 @@ import lightning.pytorch as pl from omegaconf import OmegaConf, open_dict -from nemo.collections.tts.models import MagpieTTS_Model, MagpieTTS_ModelInference, MagpieTTS_Model_PrefDataGen, MagpieTTS_Model_OfflinePO, MagpieTTS_Model_OnlinePO +from nemo.collections.tts.models import ( + MagpieTTSModel, + MagpieTTSModelInference, + MagpieTTSModelOfflinePO, + MagpieTTSModelOnlinePO, + MagpieTTSModelPrefDataGen, +) from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @@ -33,21 +39,21 @@ def main(cfg): exp_manager(trainer, cfg.get("exp_manager", None)) if cfg.get('mode', 'train') == 'train': - model = MagpieTTS_Model(cfg=cfg.model, trainer=trainer) + model = MagpieTTSModel(cfg=cfg.model, trainer=trainer) elif cfg.get('mode', 'train') == 'dpo_train': model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt - model = MagpieTTS_Model_OfflinePO(cfg=model_cfg, trainer=trainer) + model = MagpieTTSModelOfflinePO(cfg=model_cfg, trainer=trainer) elif cfg.get('mode', 'train') == 'onlinepo_train': model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt - model = MagpieTTS_Model_OnlinePO(cfg=model_cfg, trainer=trainer) + model = MagpieTTSModelOnlinePO(cfg=model_cfg, trainer=trainer) elif cfg.get('mode', 'train') == 'test': - model = MagpieTTS_ModelInference(cfg=cfg.model, trainer=trainer) + model = MagpieTTSModelInference(cfg=cfg.model, trainer=trainer) # elif cfg.get('mode', 'train') == 'test': - # model = MagpieTTS_Model_PrefDataGen(cfg=cfg.model, trainer=trainer) + # model = MagpieTTSModelPrefDataGen(cfg=cfg.model, trainer=trainer) else: raise NotImplementedError(f"Only train, dpo_train and test modes are supported. 
Got {cfg.mode}") diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index ba39b2230410..e3e7363ee892 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -445,9 +445,9 @@ def __getitem__(self, index): audio_codes_path = data.manifest_entry['target_audio_codes_path'] audio_codes = torch.load(audio_codes_path).long() # (C, T) spec_len = audio_codes.shape[1] + 1 # +1 for EOS - auidio_bos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_bos_id, dtype=audio_codes.dtype) + audio_bos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_bos_id, dtype=audio_codes.dtype) audio_eos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_eos_id, dtype=audio_codes.dtype) - audio_codes = torch.cat([auidio_bos_tensor, audio_codes, audio_eos_tensor], dim=1) + audio_codes = torch.cat([audio_bos_tensor, audio_codes, audio_eos_tensor], dim=1) audio_codes_len = audio_codes.shape[1] example['audio_codes'] = audio_codes example['audio_codes_len'] = audio_codes_len diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index 203a662f4246..b9491c56d1a0 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -17,8 +17,12 @@ from nemo.collections.tts.models.fastpitch import FastPitchModel from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL from nemo.collections.tts.models.hifigan import HifiGanModel -from nemo.collections.tts.models.magpietts import MagpieTTS_Model, MagpieTTS_ModelInference -from nemo.collections.tts.models.magpietts_preference_optimization import MagpieTTS_Model_PrefDataGen, MagpieTTS_Model_OfflinePO, MagpieTTS_Model_OnlinePO +from nemo.collections.tts.models.magpietts import MagpieTTSModel, MagpieTTSModelInference +from nemo.collections.tts.models.magpietts_preference_optimization import ( + MagpieTTSModelOfflinePO, + MagpieTTSModelOnlinePO, + MagpieTTSModelPrefDataGen, +) from nemo.collections.tts.models.mixer_tts import MixerTTSModel from nemo.collections.tts.models.radtts import RadTTSModel from nemo.collections.tts.models.spectrogram_enhancer import SpectrogramEnhancerModel @@ -40,11 +44,11 @@ "MelPsuedoInverseModel", "MixerTTSModel", "RadTTSModel", - "MagpieTTS_Model", - "MagpieTTS_ModelInference", - "MagpieTTS_Model_PrefDataGen", - "MagpieTTS_Model_OfflinePO", - "MagpieTTS_Model_OnlinePO", + "MagpieTTSModel", + "MagpieTTSModelInference", + "MagpieTTSModelPrefDataGen", + "MagpieTTSModelOfflinePO", + "MagpieTTSModelOnlinePO", "Tacotron2Model", "TwoStagesModel", "UnivNetModel", diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 189b1f25c7e3..edb1e7760b2f 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -97,7 +97,7 @@ def worker_init_fn(worker_id): dataset.text_conditioning_tokenizer = text_conditioning_tokenizer -class MagpieTTS_Model(ModelPT): +class MagpieTTSModel(ModelPT): """ Magpie-TTS Model Base Class used for training a TTS model that can generate audio codes from transcript and a context audio/text @@ -206,17 +206,14 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) # del codec discriminator to free memory del codec_model.discriminator - codec_model.eval() - self.freeze_model(codec_model) self._codec_model = codec_model + 
self._codec_model.freeze() #Lightning does requires_grad = False and self.eval() if self.model_type == 'single_encoder_sv_tts': - speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + self._speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( model_name='titanet_large' ) - speaker_verification_model.eval() - self.freeze_model(speaker_verification_model) - self._speaker_verification_model = speaker_verification_model + self._speaker_verification_model.freeze() #Lightning does requires_grad = False and self.eval() self.speaker_projection_layer = nn.Linear(cfg.speaker_emb_dim, cfg.embedding_dim) self.transcript_decoder_layers = [ idx for idx in range(cfg.decoder.n_layers) @@ -253,10 +250,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if alignment_encoder_loss_scale > 0.0: self.alignment_encoder_loss = ForwardSumLoss(loss_scale=alignment_encoder_loss_scale) - def freeze_model(self, model): - for param in model.parameters(): - param.requires_grad = False - def state_dict(self, destination=None, prefix='', keep_vars=False): if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} @@ -293,7 +286,7 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): self._codec_model.eval() with torch.no_grad(): codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) - # Add a timestep to begining and end of codes tensor + # Add a timestep to beginning and end of codes tensor bos_tensor = torch.full( (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device ) @@ -348,7 +341,7 @@ def get_speaker_embeddings(self, audio_16khz, audio_len_16khz): def compute_local_transformer_logits(self, dec_out, audio_codes_target): """ - Loss from the autoregrssive codebook predictor (used per frame) + Loss from the autoregressive codebook predictor (used per frame) """ # dec_out: (B, T', E) # audio_codes: (B, C, T') @@ -913,7 +906,7 @@ def process_batch(self, batch, mode="train"): max_codebook_val = self.cfg.get('dec_random_input_max', self.cfg.num_audio_tokens_per_codebook) # @pneekhara: Keeping dec_random_input_max configurable since num_audio_tokens_per_codebook usually has padding tokens # which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on - # audio_codes_input so should not matter if we dont supply dec_random_input_max. + # audio_codes_input so should not matter if we don't supply dec_random_input_max. 
random_audio_tokens = torch.randint( 0, max_codebook_val, audio_codes_input.size(), device=audio_codes_input.device ) @@ -1539,7 +1532,7 @@ def get_dataset(self, cfg, dataset_type): ) # This will be used in worker_init_fn for instantiating tokenizer return dataset - def _setup_train_dataloader(self, cfg): + def setup_training_data(self, cfg): dataset = self.get_dataset(cfg, dataset_type='train') sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) persistent_workers = True @@ -1547,7 +1540,7 @@ def _setup_train_dataloader(self, cfg): persistent_workers = False # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) dataset.text_tokenizer, dataset.text_conditioning_tokenizer = self._setup_tokenizers(self.cfg) - data_loader = torch.utils.data.DataLoader( + self._train_dl = torch.utils.data.DataLoader( dataset, collate_fn=dataset.collate_fn, sampler=sampler, @@ -1555,9 +1548,8 @@ def _setup_train_dataloader(self, cfg): worker_init_fn=worker_init_fn, persistent_workers=persistent_workers, ) - return data_loader - def _setup_test_dataloader(self, cfg): + def _setup_test_dataloader(self, cfg) -> torch.utils.data.DataLoader: dataset = self.get_dataset(cfg, dataset_type='test') persistent_workers = True if cfg.dataloader_params.num_workers == 0: @@ -1574,9 +1566,6 @@ def _setup_test_dataloader(self, cfg): ) return data_loader - def setup_training_data(self, cfg): - self._train_dl = self._setup_train_dataloader(cfg) - def setup_validation_data(self, cfg): self._validation_dl = self._setup_test_dataloader(cfg) @@ -1588,8 +1577,8 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: return [] -class MagpieTTS_ModelInference(MagpieTTS_Model): - """Small override of MagpieTTS_Model for parallel multi-GPU inference and metrics calculation. +class MagpieTTSModelInference(MagpieTTSModel): + """Small override of MagpieTTSModel for parallel multi-GPU inference and metrics calculation. This class is used in 'test' mode and leverages trainer.test() for multi-GPU/multi-node inference. Saves the predicted audio files and logs the CER/WER metrics as individual json files for each audio. 
""" @@ -1601,13 +1590,11 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): model_name="nvidia/parakeet-tdt-1.1b" ) self.eval_asr_model.freeze() - self.eval_asr_model.eval() self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( model_name='titanet_large' ) self.eval_speaker_verification_model.freeze() - self.eval_speaker_verification_model.eval() if cfg.get('load_whisper_model', False): from transformers import WhisperForConditionalGeneration, WhisperProcessor diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index aad60f922360..3cfa03284cad 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -1,20 +1,21 @@ +import copy +import json +import os +import random +import string + +import librosa import numpy as np +import soundfile as sf import torch -from omegaconf import DictConfig from lightning.pytorch import Trainer -from torch import nn -import os -import json -from nemo.utils import logging +from omegaconf import DictConfig, open_dict + import nemo.collections.asr as nemo_asr -import soundfile as sf -import librosa -import copy -from omegaconf import open_dict -import string from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors -import random +from nemo.utils import logging + try: import torchaudio from torchaudio.pipelines import SQUIM_OBJECTIVE @@ -22,10 +23,10 @@ except ImportError: HAVE_TORCHAUDIO = False -from nemo.collections.tts.models import MagpieTTS_Model +from nemo.collections.tts.models import MagpieTTSModel -class MagpieTTS_Model_PrefDataGen(MagpieTTS_Model): +class MagpieTTSModelPrefDataGen(MagpieTTSModel): """Small override to save inference metrics, used for datagen in Offline PO""" def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) @@ -39,7 +40,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.eval_speaker_verification_model.eval() if cfg.get('load_whisper_model', False): - from transformers import WhisperProcessor, WhisperForConditionalGeneration + from transformers import WhisperForConditionalGeneration, WhisperProcessor self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() @@ -138,7 +139,7 @@ def test_step(self, batch, batch_idx): with open(os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w') as f: json.dump(item_metrics, f) -class MagpieTTS_Model_OfflinePO(MagpieTTS_Model): +class MagpieTTSModelOfflinePO(MagpieTTSModel): def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) # Copy cfg @@ -146,7 +147,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): with open_dict(ref_model_cfg): ref_model_cfg.train_ds = None ref_model_cfg.validation_ds = None - self._reference_model = MagpieTTS_Model(cfg=ref_model_cfg) + self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) print("Loading reference model from checkpoint") self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) self.freeze_model(self._reference_model) @@ -371,7 +372,7 @@ def collect(key): 
self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() -class MagpieTTS_Model_OnlinePO(MagpieTTS_Model): +class MagpieTTSModelOnlinePO(MagpieTTSModel): def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) # Copy cfg @@ -382,7 +383,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.reference_free = self.cfg.get('reference_free', False) # True means we dont use the reference model if not self.reference_free: - self._reference_model = MagpieTTS_Model(cfg=ref_model_cfg) + self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) print("Loading reference model from checkpoint") self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) self.freeze_model(self._reference_model) @@ -395,7 +396,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.eval_asr_model.freeze() self.eval_asr_model.eval() elif cfg.get('reward_asr_model', "nemo") == "whisper": - from transformers import WhisperProcessor, WhisperForConditionalGeneration + from transformers import WhisperForConditionalGeneration, WhisperProcessor self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() @@ -407,7 +408,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.eval_speaker_verification_model.eval() if cfg.get('load_whisper_model', False): - from transformers import WhisperProcessor, WhisperForConditionalGeneration + from transformers import WhisperForConditionalGeneration, WhisperProcessor self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index a565b73267a7..634bafda5e5a 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -1,21 +1,24 @@ -from nemo.collections.tts.models import MagpieTTS_Model -from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset -from omegaconf.omegaconf import OmegaConf, open_dict -import os +import argparse +import copy import glob -import torch -import soundfile as sf -import evaluate_generated_audio -import evalset_config import json -import argparse +import os +import shutil + +import evalset_config +import evaluate_generated_audio import numpy as np import scipy.stats as stats -import copy -import shutil -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +import soundfile as sf +import torch +from omegaconf.omegaconf import OmegaConf, open_dict from PIL import Image +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset +from nemo.collections.tts.models import MagpieTTSModel + + def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): metrics = {} for key in metric_keys: @@ -65,7 +68,7 @@ def run_inference( model_cfg.validation_ds = None - model = MagpieTTS_Model(cfg=model_cfg) + model = MagpieTTSModel(cfg=model_cfg) model.use_kv_cache_for_inference = True # Load weights from checkpoint file From 21bf9c94a741665bd37b5eb157230309e9849fc3 Mon Sep 17 
00:00:00 2001 From: Jason Date: Tue, 8 Apr 2025 12:32:35 -0400 Subject: [PATCH 014/113] undo unintended change in #45 (#53) Signed-off-by: Jason --- scripts/magpietts/evalset_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 77caaa788c00..b44c1601456a 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -1,8 +1,8 @@ dataset_meta_info = { 'vctk': { - 'manifest_path' : '/home/jasoli/data_prime/manifests/rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', - 'audio_dir' : '/home/jasoli/data_prime/RIVA-TTS/en/', - 'feature_dir' : '/home/jasoli/data_prime/codecs', + 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json', + 'audio_dir' : '/datap/misc/Datasets/VCTK-Corpus', + 'feature_dir' : '/datap/misc/Datasets/VCTK-Corpus', }, 'riva_challenging': { 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', From 790e8f990d7795224badcd8822628582f77c90a0 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Wed, 9 Apr 2025 12:34:52 -0700 Subject: [PATCH 015/113] Autocasting diasbled for codec model and Making prior window strict (#56) * trainer import fix for new pytorch lightning Signed-off-by: Paarth Neekhara * handle strict prior window correctly Signed-off-by: Paarth Neekhara * disable autocasting codec model and making prior window strict Signed-off-by: Paarth Neekhara --------- Signed-off-by: Paarth Neekhara --- examples/tts/conf/magpietts/magpietts_en.yaml | 1 - .../magpietts/magpietts_inference_en.yaml | 1 - .../magpietts_inference_multilingual_v1.yaml | 1 - .../tts/conf/magpietts/magpietts_lhotse.yaml | 1 - .../magpietts/magpietts_multilingual_v1.yaml | 1 - nemo/collections/tts/models/magpietts.py | 54 ++++++++++--------- .../tts/modules/transformer_2501.py | 28 ++++++---- scripts/magpietts/codec_extraction.py | 5 +- 8 files changed, 49 insertions(+), 43 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index cc0694f66283..12deab788047 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -152,7 +152,6 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true - prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index b1b042f8663d..38b383fcfc58 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -142,7 +142,6 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true - prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index 74ca87fea48c..5cbaf589ed1b 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -190,7 +190,6 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true 
- prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/examples/tts/conf/magpietts/magpietts_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_lhotse.yaml index 17b3597a9a3f..ee96e8307d78 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse.yaml @@ -221,7 +221,6 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true - prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index 9f3a4d14f03d..e262833ba907 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -198,7 +198,6 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true - prior_eps: 1e-8 optim: _target_: torch.optim.Adam diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index edb1e7760b2f..b82c07a64e68 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -284,37 +284,39 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): raise ValueError(f"Received audio_type of {audio_type}. Must be `target` or `context`") self._codec_model.eval() - with torch.no_grad(): - codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) - # Add a timestep to beginning and end of codes tensor - bos_tensor = torch.full( - (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device - ) - pad_tensor = torch.full( - (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device - ) # 0 is the padding token in the audio codebook - codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1) - # codes: (B, C, T') - # codes_len: (B,) - for idx in range(codes.size(0)): - codes[idx, :, codes_len[idx] + 1] = audio_eos_id - codes_len = codes_len + 2 - - return codes.long(), codes_len.long() + with torch.cuda.amp.autocast(enabled=False): + with torch.no_grad(): + codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) + # Add a timestep to begining and end of codes tensor + bos_tensor = torch.full( + (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device + ) + pad_tensor = torch.full( + (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device + ) # 0 is the padding token in the audio codebook + codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1) + # codes: (B, C, T') + # codes_len: (B,) + for idx in range(codes.size(0)): + codes[idx, :, codes_len[idx] + 1] = audio_eos_id + codes_len = codes_len + 2 + + return codes.long(), codes_len.long() def codes_to_audio(self, codes, codes_len): # codes: (B, C, T') # codes_len: (B,) self._codec_model.eval() - with torch.no_grad(): - # Replace eos and bos tokens with padding in codes tensor - codes[codes == self.audio_bos_id] = 0 # zero is the padding token in the audio codebook - codes[codes == self.audio_eos_id] = 0 - # self.additional_models['codec'] = self.additional_models['codec'].to(codes.device) - audio, audio_len = self._codec_model.decode(tokens=codes, tokens_len=codes_len) - # audio: (B, T) - # audio_len: (B,) - return audio, audio_len + with torch.cuda.amp.autocast(enabled=False): + with torch.no_grad(): + # Replace eos and bos tokens with padding in codes tensor + codes[codes == self.audio_bos_id] = 0 # zero is the padding 
token in the audio codebook + codes[codes == self.audio_eos_id] = 0 + # self.additional_models['codec'] = self.additional_models['codec'].to(codes.device) + audio, audio_len = self._codec_model.decode(tokens=codes, tokens_len=codes_len) + # audio: (B, T) + # audio_len: (B,) + return audio, audio_len def embed_audio_tokens(self, audio_tokens): # audio_tokens: (B, C, T') diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index 2d56b1c52465..b72cef6d3f92 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -227,11 +227,16 @@ def attn_naive( # attn_prior or square mask or vanilla attention if attn_prior is not None: - eps = self.prior_eps + eps = 1e-8 attn_prior = attn_prior[:, :T] # trim for inference - attn_prior = torch.log(attn_prior + eps) - attn_prior = attn_prior[:, None].repeat(1, self.n_heads, 1, 1) - attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior + attn_prior = attn_prior[:, None] + attn_prior_log = torch.log(attn_prior + eps) + attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior_log + if self.make_prior_window_strict: + # Make sure attention scores are lowest (eps) where prior is zero. + min_score = torch.log(torch.tensor(eps)).to(attn_score_log.device) + attn_score_log = attn_score_log.masked_fill(attn_prior == 0, min_score) # Wherever prior is zero, set scores to eps. + attn_score_log = torch.clamp(attn_score_log, min=min_score) # Make sure scores are not less than eps. attn_prob = F.softmax(attn_score_log, dim=-1) else: attn_prob = F.softmax(attn_score, dim=-1) @@ -344,7 +349,7 @@ def __init__( d_model: int, d_memory: int, p_dropout: float, - prior_eps: float = 1e-8, + make_prior_window_strict: bool = False, ): """ Implements CrossAttention. See parent class for forward implementation. Must be non-causal. @@ -354,6 +359,7 @@ def __init__( d_model (int): Dimension of the model. d_memory (int): Dimension of the conditioning / cross-attention input. p_dropout (float): Dropout probability. + make_prior_window_strict (bool): Make attention scores lowest where prior is zero. """ super().__init__( n_heads=n_heads, @@ -365,7 +371,7 @@ def __init__( raise ValueError("d_memory must be provided for cross-attention") self.q_net = torch.nn.Linear(d_model, n_heads * self.d_head, bias=False) self.kv_net = torch.nn.Linear(d_memory, 2 * n_heads * self.d_head, bias=False) - self.prior_eps = prior_eps + self.make_prior_window_strict = make_prior_window_strict def compute_qkv_and_mask( self, @@ -412,7 +418,7 @@ def __init__( apply_norm_to_cond: bool = True, max_length_causal_mask: int = 4096, conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), - prior_eps: float = 1e-8, + make_prior_window_strict: bool = False, ): """ One layer of the Transformer. @@ -429,6 +435,7 @@ def __init__( apply_norm_to_cond : Whether to apply normalization to conditioning tensor max_length_causal_mask : Maximum length of causal mask conv_non_linearity : Convolution non-linearity + make_prior_window_strict : Make attention scores lowest where prior is zero. 
""" super().__init__() self.has_xattn = has_xattn @@ -450,7 +457,7 @@ def __init__( d_model=d_model, d_memory=xa_d_memory, p_dropout=p_dropout, - prior_eps=prior_eps, + make_prior_window_strict=make_prior_window_strict, ) if self.apply_norm_to_cond: @@ -556,7 +563,7 @@ def __init__( max_length_causal_mask: int = 4096, use_learnable_pos_emb: bool = False, conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), - prior_eps: float = 1e-8, + make_prior_window_strict: bool = False, ): """ Initializes a stack of transformer layers. Can be used for both encoder and decoder. @@ -579,6 +586,7 @@ def __init__( max_length_causal_mask : Maximum length of causal mask use_learnable_pos_emb : Whether to add a learnable positionable embedding inside the class conv_non_linearity : Convolution non-linearity + make_prior_window_strict : Make attention scores lowest where prior is zero """ if has_xattn and (xa_d_memory is None or xa_n_heads is None): raise ValueError("It requires that `xa_d_memory` and `xa_n_heads` are specified when `has_xattn` is True!") @@ -614,7 +622,7 @@ def __init__( apply_norm_to_cond=apply_norm_to_cond, max_length_causal_mask=max_length_causal_mask, conv_non_linearity=conv_non_linearity, - prior_eps=prior_eps, + make_prior_window_strict=make_prior_window_strict, ) ) diff --git a/scripts/magpietts/codec_extraction.py b/scripts/magpietts/codec_extraction.py index a4a827f76089..83c164a9a6c8 100644 --- a/scripts/magpietts/codec_extraction.py +++ b/scripts/magpietts/codec_extraction.py @@ -86,8 +86,9 @@ def __init__(self, model_path): self.codec_model.eval() def forward(self, batch): - with torch.no_grad(): - codes, codes_lengths = self.codec_model.encode(audio=batch["audios"], audio_len=batch["audio_lengths"]) + with torch.cuda.amp.autocast(enabled=False): + with torch.no_grad(): + codes, codes_lengths = self.codec_model.encode(audio=batch["audios"], audio_len=batch["audio_lengths"]) return codes, codes_lengths def predict_step(self, batch, batch_idx, dataloader_idx=0): From e63b61df926ddff1b3e18b2fbb4c6d1101ad677b Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:43:35 -0700 Subject: [PATCH 016/113] [magpie][lhotse] Add lhoste shar dataset prep recipe and lhotse dataloader. (#54) * [magpie][lhotse] added a lhotse dataloader for monologue tts. this is a working recipe with num_workers>0 for training and num_workers=0 for val datasets. Still faced issues when num_workers>0 during validation steps. Investigating rootcauses. * all contents in a batch are obtained correctly, but dtype mismatches. * fix dtype for text tokens and codec codes. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [lhotse_shar_prep] add script to create shar dataset. * with more efficient changes. * bugfix previously the last batch would be dropped if the size is less than the buffer size. this fixes it. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [lhotse_dataloader] clean up commented lines. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [lhotse_dataloader] bugfix to force spawn over fork to address CUDA initialization errors when multiple workers are used during validation. 
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [lhotse_dataloader] save efforts to set up tokenizer again for training since it has been setup ready during model initialization. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [lhotse_dataloader] switch to setup tokenizer inside __getitem__ to support spawn worker processes. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpietts][lhotse] fixed a bug of attatch_tensor which save wrong numpy array. update yaml config Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpie][lhotse_config] enforce quadratic_duration if using lhotse dataloader to avoid frequent OOMs. changed yaml name to monologue Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpie][example] add LR logger. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * cleanup Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [lhotse_yaml] made changes for yaml config according to comments. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [magpie][lhotse_dataset] added docstring for lhotse dataset Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpie][lhotse_dataset] remove yamls Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [magpie][lhotse_dataset] remove Edresson's lhotse implementations, and update yaml name. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [magpie][lhotse_dataset] add a README showing guidance how to create lhotse data Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [magpie][lhotse_dataset] update MonoCut example. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * rename config Signed-off-by: Jason --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Jason Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Jason --- .../tts/conf/magpietts/magpietts_lhotse.yaml | 269 -------- .../conf/magpietts/magpietts_lhotse_en.yaml | 211 ++++++ examples/tts/magpietts.py | 17 +- .../common/data/lhotse/dataloader.py | 4 +- .../text_to_speech/tts_tokenizers.py | 3 + .../tts/data/text_to_speech_dataset.py | 4 +- .../tts/data/text_to_speech_dataset_lhotse.py | 646 +++++++++++------- nemo/collections/tts/models/magpietts.py | 170 ++--- scripts/magpietts/README_lhotse.md | 174 +++++ .../magpietts/convert_nemo_to_lhotse_shar.py | 464 +++++++++++++ 10 files changed, 1378 insertions(+), 584 deletions(-) delete mode 100644 examples/tts/conf/magpietts/magpietts_lhotse.yaml create mode 100644 examples/tts/conf/magpietts/magpietts_lhotse_en.yaml create mode 100644 scripts/magpietts/README_lhotse.md create mode 100644 scripts/magpietts/convert_nemo_to_lhotse_shar.py diff --git a/examples/tts/conf/magpietts/magpietts_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_lhotse.yaml deleted file mode 100644 index ee96e8307d78..000000000000 --- a/examples/tts/conf/magpietts/magpietts_lhotse.yaml +++ /dev/null @@ -1,269 +0,0 @@ -name: Magpie-TTS-Lhotse - -max_steps: ??? -limit_val_batches: ??? -# Adjust batch size based on GPU memory -batch_size: 16 -micro_batch_size: 16 -batch_duration: ??? 
-eval_batch_size: ??? - -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -train_ds_meta: ??? -val_ds_meta: ??? - -# Modify these values based on your sample rate -sample_rate: 22050 - -phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" -heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" -model: - use_lhotse: true - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids - codec_model_downsample_factor: 1024 - load_cached_codes_if_available: true - prior_scaling_factor: 0.5 - prior_end_step: 12000 - prior_scaledown_start_step: 8000 - indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. - alignment_loss_scale: 0.0 - embedding_dim: 768 - codecmodel_path: ??? - max_steps: ${max_steps} - - sample_rate: ${sample_rate} - - # Alignment encoder parameters, to binarize the prior - # This is used for attention-constrained training and inference - use_alignment_encoder: false - # Below args are only relevant if use_alignment_encoder is true - use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder - alignment_encoder_loss_scale: 1.0 - binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. - binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. - prior_future_context: 2 # Future window of the binarized prior. - prior_past_context: 2 # Past window of the binarized prior. - prior_future_decay: 0.8 # Decay factor for future context - prior_past_decay: 0.5 # Decay factor for past context - binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. 
Increase this for low frame rate codecs - binarized_prior_epsilon: 0.0 - aligner_encoder_train_steps: 50000 - - # Local transformer parameters for autoregressive codebook prediction within a frame - use_local_transformer: false - # Below args are only relevant if use_local_transformer is true - local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 - local_transformer_n_heads: 1 - local_transformer_hidden_dim: 256 - - text_tokenizers: - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: ${phoneme_dict_path} - heteronyms: ${heteronyms_path} - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - - train_ds: - use_lhotse: ${model.use_lhotse} - dataset: - input_cfg: - - type: lhotse_shar - shar_path: /cluster_data/TTS/tts_lhotse_datasets/hifitts_v0/ - weight: 1.0 - tags: - lang: en - s2s: True - tokenizer_names: ["english_phoneme"] - - - type: lhotse_shar - shar_path: /cluster_data/TTS/tts_lhotse_datasets/libri100/ - weight: 1.0 - tags: - lang: en - s2s: True - tokenizer_names: ["english_phoneme"] - - - type: lhotse_shar - shar_path: /cluster_data/TTS/tts_lhotse_datasets/rivaLindyRodney/ - weight: 1.0 - tags: - lang: en - s2s: True - tokenizer_names: ["english_phoneme"] - - - type: lhotse_shar - shar_path: /cluster_data/TTS/tts_lhotse_datasets/libri360/ - weight: 1.0 - tags: - lang: en - s2s: True - tokenizer_names: ["english_phoneme"] - - global_batch_size: ${batch_size} - micro_batch_size: ${micro_batch_size} - batch_size: null - shuffle: True - num_workers: 0 - pin_memory: True - max_seq_length: 512 - min_seq_length: 1 - drop_last: True - # Notably, the data weights are controlled by either bucketing_weights - # or concat_sampling_probabilities depending on the dataset type (tar and - # non-tar). - # See audio_text_qa_dataset.py for details. 
- concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' - # ASR configs - sample_rate: ${model.sample_rate} - max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset - min_duration: 0.1 - # tarred datasets - is_tarred: false - # tarred_audio_filepaths: null - shuffle_n: 2048 - # bucketing params - bucketing_strategy: "fully_randomized" - bucketing_batch_size: null - use_bucketing: true - use_lhotse: ${model.use_lhotse} - text_field : "text" - seed: 'trng' - batch_duration : ${batch_duration} # 0 - quadratic_duration : 20 - num_buckets : 31 - buffer_size : 10000 - shuffle_buffer_size : 10000 - num_cuts_for_bins_estimate: 10000 - duration_bins: [3.155,3.76,4.27,4.74,5.1935,5.64,6.096,6.588,7.14,7.81,8.28,8.664,9.072,9.57,10.14,10.7335,11.3735,12.09,12.78,13.41,14.01,14.62,15.253375,15.96875,16.71,17.45,18.1335,18.7735,19.4,20.0] - # bucket_duration_bins: [3.155,3.76,4.27,4.74,5.1935,5.64,6.096,6.588,7.14,7.81,8.28,8.664,9.072,9.57,10.14,10.7335,11.3735,12.09,12.78,13.41,14.01,14.62,15.253375,15.96875,16.71,17.45,18.1335,18.7735,19.4,20.0] - - validation_ds: - use_lhotse: ${model.use_lhotse} - dataset: - input_cfg: - - type: lhotse_shar - shar_path: /cluster_data/TTS/tts_lhotse_datasets/LibriTTS_dev_clean/ - weight: 1.0 - tags: - lang: en - s2s: True - tokenizer_names: ["english_phoneme"] - - global_batch_size: ${batch_size} - micro_batch_size: ${micro_batch_size} - shuffle: False - num_workers: 0 - pin_memory: True - drop_last: False - use_bucketing: false - is_tarred: false - batch_size: ${eval_batch_size} - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: False - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 768 - xa_n_heads: 12 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: 32 - max_epochs: -1 - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 10 - val_check_interval: 500 - limit_train_batches: ${trainer.val_check_interval} - # check_val_every_n_epoch: 10 - benchmark: false - max_steps: ${max_steps} - limit_val_batches: ${limit_val_batches} - use_distributed_sampler: false - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - resume_if_exists: true - 
resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_en.yaml new file mode 100644 index 000000000000..f61cfc4c337a --- /dev/null +++ b/examples/tts/conf/magpietts/magpietts_lhotse_en.yaml @@ -0,0 +1,211 @@ +name: MagpieTTS-EN-Lhotse + +quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. +sample_rate: 22_050 + +model: + use_lhotse: true + model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer + use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. + transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. + context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts + context_duration_min: 3.0 + context_duration_max: 5.0 + speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise + num_audio_codebooks: 8 + num_audio_tokens_per_codebook: 2_048 # Keep atleast 2 extra for eos/bos ids + codec_model_downsample_factor: 1_024 + codec_model_name: "21fpsCausalDecoder" + load_cached_codes_if_available: true + prior_scaling_factor: 0.5 + prior_end_step: 12_000 + prior_scaledown_start_step: 8_000 # Prior will always be on before this step. + indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. + alignment_loss_scale: 0.002 + embedding_dim: 768 + codecmodel_path: ??? + sample_rate: ${sample_rate} + cfg_unconditional_prob: 0.1 # enable classifier-free guidance during traing by dropping out conditionals with this probability + + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10_000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. 
Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50_000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + use_local_transformer: false + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + + text_tokenizers: # Add more languages for multi-lingual TTS + english_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + + train_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + sample_rate: ${sample_rate} # need to override the default sample rate of 16000 + batch_duration : ??? # in seconds. Adjust based on your GPU memory. + quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 10 + bucket_buffer_size: 20_000 + shuffle_buffer_size: 20_000 + num_cuts_for_bins_estimate: 20_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 4 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + validation_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + sample_rate: ${sample_rate} + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. + quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + drop_last: false + shuffle: false + num_workers: 4 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + encoder: + n_layers: 6 + d_model: 768 + d_ffn: 3_072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: false + apply_norm_out: true + max_length_causal_mask: 2_048 + use_learnable_pos_emb: true + + context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise + n_layers: 3 + d_model: 768 + d_ffn: 3_072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: false + apply_norm_out: true + max_length_causal_mask: 2_048 + use_learnable_pos_emb: true + + decoder: + n_layers: 12 + d_model: 768 + d_ffn: 3_072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: true + xa_d_memory: 768 + xa_n_heads: 12 + is_causal: true + apply_norm_to_cond: true + apply_norm_out: true + max_length_causal_mask: 2_048 + use_learnable_pos_emb: true + + optim: + _target_: torch.optim.Adam + lr: 2e-4 + betas: [0.8, 0.99] + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: bf16-mixed + max_steps: ??? 
+ accumulate_grad_batches: 1 + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + limit_train_batches: 1_000 + val_check_interval: 1_000 + num_sanity_val_steps: 0 + benchmark: false + use_distributed_sampler: false # required because Lhotse has its own handling + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + entity: null + project: null + group: null + name: null + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}-{epoch}' + resume_if_exists: true + resume_ignore_no_checkpoint: true diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index a89761207e2b..c96c55d3dee4 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -13,6 +13,7 @@ # limitations under the License. import lightning.pytorch as pl +import torch.multiprocessing as mp from omegaconf import OmegaConf, open_dict from nemo.collections.tts.models import ( @@ -20,7 +21,6 @@ MagpieTTSModelInference, MagpieTTSModelOfflinePO, MagpieTTSModelOnlinePO, - MagpieTTSModelPrefDataGen, ) from nemo.core.config import hydra_runner from nemo.utils import logging @@ -30,12 +30,21 @@ @hydra_runner(config_path="conf/magpietts", config_name="magpietts_en") def main(cfg): logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg, resolve=True)) - if not cfg.model.get('use_lthose', False): - import torch.multiprocessing as mp - mp.set_start_method("spawn", force=True) + # forcing "spawn" method for multiprocessing over "fork" when choosing multiple + # worker processes for dataloaders. By default, multiprocessing uses "fork" to create + # worker processes, which inherit the memory state of the main process, including its + # already initialized CUDA state. When the worker processes trieds to use + # CUDA, it runs into conflicts with the inherited, now potentially invalid, + # CUDA context, resuling in the CUDA initialization error. When + # num_workers=0, all dataloading happens in the main process, so there is no + # process forking and no CUDA context conflict. When num_workers>0, the standard way + # to fix this is to use "spawn" to create a completely new and clean python process for + # each worker, avoding the problematic CUDA state inheritance. + mp.set_start_method("spawn", force=True) trainer = pl.Trainer(**cfg.trainer) + trainer.callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step', log_weight_decay=True)) exp_manager(trainer, cfg.get("exp_manager", None)) if cfg.get('mode', 'train') == 'train': diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index a9d9cd7aa81d..fb3c97bc9bdb 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -221,7 +221,7 @@ def get_lhotse_dataloader_from_config( tokenizer=None, ) -> torch.utils.data.DataLoader: """ - Set up a Lhotse training dataloder. + Set up a Lhotse training dataloader. Expects a typical NeMo dataset configuration format, with additional fields: "use_lhotse=True". Some fields in the original NeMo configuration may be ignored. 
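For reference, a minimal, self-contained sketch of the worker start-method pattern introduced in the `examples/tts/magpietts.py` hunk above: forcing "spawn" so that DataLoader worker processes start from a clean interpreter instead of inheriting the parent's forked, already-initialized CUDA context. The toy dataset, the `__main__` guard, and the loop are illustrative only and are not part of this patch.

```python
# Sketch only: demonstrates why "spawn" is forced for dataloader workers.
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Placeholder dataset; stands in for the NeMo/Lhotse datasets used in the patch."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])


if __name__ == "__main__":
    # "spawn" creates fresh worker processes, so they never see the parent's
    # already-initialized CUDA context (the failure mode described in the comment above).
    mp.set_start_method("spawn", force=True)
    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=2)
    for batch in loader:
        print(batch.shape)  # torch.Size([2, 1])
```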
@@ -267,7 +267,7 @@ def get_lhotse_dataloader_from_single_config( tokenizer=None, ) -> torch.utils.data.DataLoader: """ - Set up a Lhotse training dataloder. + Set up a Lhotse training dataloader. Expects a typical NeMo dataset configuration format, with additional fields: "use_lhotse=True". Some fields in the original NeMo configuration may be ignored. diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 988cb853a9e8..c5509d3ffa30 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -1086,6 +1086,8 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): return [self._token2id[p] for p in ps] +# TODO @xueyang: subclassing from `nemo/collections/common/tokenizers/tokenizer_spec.py::TokenizerSpec`, and/or +# adjust to reuse `nemo/collections/common/tokenizers/aggregate_tokenizer.py::AggregateTokenizer` class AggregatedTTSTokenizer: def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase]], tokenizer_names: List[str]): """A simple aggregated tokenizer. Aggregates multiple tokenizers into one by combining (simply concatenating) @@ -1126,3 +1128,4 @@ def encode(self, text: str, tokenizer_name: str) -> List[int]: def decode(self, tokens: List[int], tokenizer_name: str) -> str: tokenizer = self.tokenizers[tokenizer_name] return tokenizer.decode([token - self.toknizer_offsets[tokenizer_name] for token in tokens]) + diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index e3e7363ee892..13054b4b9ae2 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -396,7 +396,7 @@ def __init__( max_duration=max_duration, volume_norm=volume_norm, ) - self.bos_id = bos_id + self.bos_id = bos_id # TODO @xueyang: this should be removed since no other places used it. self.eos_id = eos_id self.audio_bos_id = audio_bos_id self.audio_eos_id = audio_eos_id @@ -720,6 +720,7 @@ def collate_fn(self, batch: List[dict]): if len(context_audio_codes_list) > 0: batch_context_audio_codes_len = torch.IntTensor(context_audio_codes_len_list) context_audio_codes_max_len = int(batch_context_audio_codes_len.max().item()) + # TODO @xueyang: verify if batch_context_audio_codes are integer. batch_context_audio_codes = stack_tensors(context_audio_codes_list, max_lens=[context_audio_codes_max_len]) batch_dict['context_audio_codes'] = batch_context_audio_codes batch_dict['context_audio_codes_lens'] = batch_context_audio_codes_len @@ -727,6 +728,7 @@ def collate_fn(self, batch: List[dict]): if self.use_text_conditioning_tokenizer: batch_context_text_tokens_len = torch.IntTensor(context_text_tokens_len_list) context_text_tokens_max_len = int(batch_context_text_tokens_len.max().item()) + # TODO @xueyang: potential bugs if self.tokenizer.pad is not 0.0. verify if batch_context_text_tokens are integer. 
batch_context_text_tokens = stack_tensors(context_text_tokens_list, max_lens=[context_text_tokens_max_len]) batch_dict['context_text_tokens'] = batch_context_text_tokens batch_dict['context_text_tokens_lens'] = batch_context_text_tokens_len diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index 0a6b8e9f66b7..7daa0af57e61 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,48 +12,67 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -from pathlib import Path -from typing import List, Optional +import random +import re +from typing import Dict, List, Union -import librosa +import numpy as np import torch -import random -from lhotse.dataset.collation import collate_vectors as collate_vectors_lhotse +from hydra.utils import instantiate +from lhotse import CutSet +from lhotse.dataset.collation import collate_matrices, collate_vectors from megatron.core import parallel_state -from omegaconf.omegaconf import OmegaConf +from omegaconf import DictConfig +from transformers import AutoTokenizer, T5Tokenizer from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config -from nemo.collections.tts.parts.utils.tts_dataset_utils import beta_binomial_prior_distribution, stack_tensors +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + beta_binomial_prior_distribution, + normalize_volume, + stack_tensors, +) from nemo.utils import logging -from nemo.utils.decorators import experimental - - -def collate_vectors(items, max_length: int, padding_value): - vectors = collate_vectors_lhotse(items, padding_value=padding_value) - if max_length > vectors.size(1): - vectors = torch.cat( - [vectors, padding_value * torch.ones(vectors.size(0), max_length - vectors.size(1), dtype=vectors.dtype)], - dim=1, - ) - if items[0].shape[0] < 1: - vectors = vectors.long() - return vectors - -def normalize_volume_torch(audio, volume_level: float = 0.95): - """Apply peak normalization to the input audio. 
- """ - if not (0.0 <= volume_level <= 1.0): - raise ValueError(f"Volume must be in range [0.0, 1.0], received {volume_level}") - if audio.size == 0: - return audio +SUPPORTED_CODEC_MODEL_NAMES = ["21fpsCausalDecoder", "12fpsCausalDecoder"] + + +def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mode='train'): + # Being used in both model and worker_init_fn, so it is defined here + # Returns two tokenizers: one for TTS transcript and one for conditioning text (if needed) + tokenizers = [] + tokenizer_names = [] + for tokenizer_name in all_tokenizers_config: + tokenizer_config = all_tokenizers_config[tokenizer_name] + if tokenizer_config._target_ == 'AutoTokenizer': + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.pretrained_model) + else: + text_tokenizer_kwargs = {} + if "g2p" in tokenizer_config: + text_tokenizer_kwargs["g2p"] = instantiate(tokenizer_config.g2p) + tokenizer = instantiate(tokenizer_config, **text_tokenizer_kwargs) + # TODO @xueyang: is it really necessary to set phone probability to 1.0 for test mode? + if mode == 'test' and hasattr(tokenizer, "set_phone_prob"): + tokenizer.set_phone_prob(1.0) + tokenizers.append(tokenizer) + tokenizer_names.append(tokenizer_name) + + aggregated_tokenizer = AggregatedTTSTokenizer(tokenizers, tokenizer_names) # TTS Transcript tokenizer + text_conditioning_tokenizer = None - max_sample = torch.max(torch.abs(audio)) - if max_sample == 0: - return audio + if use_text_conditioning_tokenizer: + # TODO: make this configurable + # Conditioning text tokenizer + text_conditioning_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + + return aggregated_tokenizer, text_conditioning_tokenizer + + +def check_speaker_format(item: str): + # enforce the format as example like "| Language:en Dataset:HiFiTTS Speaker:9136_other |". + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + return bool(re.match(pattern, item)) - return volume_level * (audio / torch.max(torch.abs(audio))) def build_lhotse_dataloader(dataset, data_cfg, is_eval=False): """Buld dataloader given an input dataset.""" @@ -67,250 +86,415 @@ def build_lhotse_dataloader(dataset, data_cfg, is_eval=False): class MagpieTTSLhotseDataset(torch.utils.data.Dataset): """ - Class for processing and loading text to speech training examples. + A PyTorch Dataset for loading and processing Text-to-Speech data for + MagpieTTS models using Lhotse CutSets, specifically designed for datasets + with text or audio context. But either context can be optional. + + This dataset expects Lhotse Cut objects where each cut represents a + target utterance along with its preceding context. Context can be + audio (preferred) or text. It handles loading either pre-computed audio + codes or raw audio waveforms, applying volume normalization, and tokenizing + text transcripts. Context audio/codes are sliced or repeated to fit within + a specified duration range. Optionally, it loads 16kHz audio suitable for + speaker verification models and calculates alignment priors. + + Tokenizers (for target text and optional context text) are initialized lazily + within each dataloader worker process upon first access. Args: - sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will - be resampled. - text_tokenizer: Tokenizer to apply to the text field. - speaker_path: Optional, path to JSON file with speaker indices, for multi-speaker training. 
Can be created with - scripts.dataset_processing.tts.create_speaker_map.py - featurizers: Optional, list of featurizers to load feature data from. Should be the same config provided - when running scripts.dataset_processing.tts.compute_features.py before training. - feature_processors: Optional, list of feature processors to run on training examples. - align_prior_hop_length: Optional int, hop length of audio features. - If provided alignment prior will be calculated and included in batch output. Must match hop length - of audio features used for training. - min_duration: Optional float, if provided audio files in the training manifest shorter than 'min_duration' - will be ignored. - max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration' - will be ignored. - volume_norm: Whether to apply volume normalization to loaded audio. + sample_rate (int): Target sample rate for loading audio. Audio will be + resampled if necessary. + volume_norm (bool): If True, applies peak volume normalization to audio + waveforms. Defaults to True. + codec_model_downsample_factor (int): The total downsampling factor of the + audio codec model used to generate codes. Used for padding audio + and calculating number of codec frames. + codec_model_name (str): Name identifier for the codec model, used to + determine the field name for loading cached codes (e.g., "codes_21fpsCausalDecoder"). + Defaults to "21fpsCausalDecoder". Supported values defined in + `SUPPORTED_CODEC_MODEL_NAMES`. + audio_bos_id (int): Token ID representing the beginning-of-sequence (BOS) for + target audio codes. + audio_eos_id (int): Token ID representing the end-of-sequence (EOS) for target + audio codes. + context_audio_bos_id (int): Token ID representing the beginning-of-sequence (BOS) + for context audio codes. + context_audio_eos_id (int): Token ID representing the end-of-sequence (EOS)for + context audio codes. + num_audio_codebooks (int): Number of codebooks used by the audio codec model. + Needed for creating dummy context codes if necessary. + prior_scaling_factor (Optional[float]): Scaling factor for the beta-binomial + alignment prior calculation. If None, priors are not computed. Defaults to None. + load_cached_codes_if_available (bool): If True, attempts to load pre-computed + audio codes from custom fields in the Lhotse Cut (e.g., 'codes_21fpsCausalDecoder', + 'context_codes_21fpsCausalDecoder'). Falls back to loading audio if codes + are not found. Defaults to True. + dataset_type (str): Specifies the mode ('train' or 'test'), mainly affecting + tokenizer settings like phoneme probability. Defaults to 'train'. + load_16khz_audio (bool): If True, loads 16kHz audio suitable for speaker + verification models. It prioritizes context audio ('context_recording' field) + if available, otherwise uses the target audio ('recording' field). + Defaults to True. + pad_context_text_to_max_duration (bool): If True and `use_text_conditioning_tokenizer` + is True, pads the tokenized context text to a length derived from + `context_duration_max`. Defaults to False. + context_duration_min (float): Minimum duration (in seconds) for the context + audio/codes. Context shorter than this will be repeated. Defaults to 3.0. + context_duration_max (float): Maximum duration (in seconds) for the context + audio/codes. Context longer than this will be sliced randomly. Defaults to 10.0. 
+ use_text_conditioning_tokenizer (bool): If True, enables processing of context + text using a separate tokenizer (currently T5Tokenizer). Expects context text + in `cut.supervisions[0].custom['context_text']`. Defaults to False. + tokenizer_config (Optional[DictConfig]): Configuration for the text tokenizers. + Used for lazy initialization within workers. Must be provided if tokenizers + are not set externally. Defaults to None. """ def __init__( self, sample_rate: int, - align_prior_hop_length: Optional[int] = None, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, volume_norm: bool = True, codec_model_downsample_factor: int = None, - bos_id: int = None, - eos_id: int = None, + codec_model_name: str = "21fpsCausalDecoder", audio_bos_id: int = None, audio_eos_id: int = None, + context_audio_bos_id: int = None, + context_audio_eos_id: int = None, + num_audio_codebooks: int = None, prior_scaling_factor: float = None, load_cached_codes_if_available: bool = True, dataset_type: str = 'train', - tokenizer_config=None, load_16khz_audio: bool = True, - use_text_conditioning_tokenizer: bool = False, pad_context_text_to_max_duration: bool = False, context_duration_min: float = 3.0, - context_duration_max: float = 10.0 + context_duration_max: float = 10.0, + use_text_conditioning_tokenizer: bool = False, + tokenizer_config: DictConfig = None, ): super().__init__() self.sample_rate = sample_rate - self.text_tokenizer = None - self.align_prior_hop_length = align_prior_hop_length self.volume_norm = volume_norm - - self.bos_id = bos_id - self.eos_id = eos_id self.audio_bos_id = audio_bos_id self.audio_eos_id = audio_eos_id + self.context_audio_bos_id = context_audio_bos_id + self.context_audio_eos_id = context_audio_eos_id + + if codec_model_name not in SUPPORTED_CODEC_MODEL_NAMES: + raise ValueError(f"Invalid `codec_model_name`: {codec_model_name}.") + self.codec_model_name = codec_model_name self.codec_model_downsample_factor = codec_model_downsample_factor + self.num_audio_codebooks = num_audio_codebooks + self.include_align_prior = prior_scaling_factor is not None self.prior_scaling_factor = prior_scaling_factor self.load_cached_codes_if_available = load_cached_codes_if_available - self.dataset_type = dataset_type - self.tokenizer_config = tokenizer_config + self.dataset_type = dataset_type # 'train' or 'test' self.load_16khz_audio = load_16khz_audio self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer - self.text_conditioning_tokenizer = None self.pad_context_text_to_max_duration = pad_context_text_to_max_duration self.context_duration_min = context_duration_min self.context_duration_max = context_duration_max + self.tokenizer_config = tokenizer_config + self.text_tokenizer = None + self.text_conditioning_tokenizer = None - def __getitem__(self, cuts): - cuts = cuts.sort_by_duration() - - logging.debug(f"Len: {len(cuts)}") - - # load audios and text - num_codec_frames = [] - align_priors = [] - context_audios = [] - context_audios_lens = [] - target_audios = [] - target_audios_lens = [] - target_audios_16khz = [] - target_audios_16khz_lens = [] - context_text_tokens = [] - context_text_tokens_lens = [] - has_text_context_list = [] - target_text_tokens = [] - target_text_tokens_lens = [] - - for i, cut in enumerate(cuts): - # load target/answer audio - answer_audio = torch.FloatTensor(cut.target_audio.resample(self.sample_rate).load_audio()).squeeze(0) - if self.volume_norm: - answer_audio = normalize_volume_torch(answer_audio) - - answer_audio = 
torch.nn.functional.pad( - answer_audio, - (0, self.codec_model_downsample_factor - (answer_audio.shape[0] % self.codec_model_downsample_factor)), - value=0 - ).unsqueeze(0) - - answer_audio_len = answer_audio.shape[1] - target_audios.append(answer_audio) - target_audios_lens.append(answer_audio_len) - num_frames = int(answer_audio_len / self.codec_model_downsample_factor) + 1 # +1 for EOS - num_codec_frames.append(num_frames) - - # load context audio - context_audio = torch.FloatTensor(cut.resample(self.sample_rate).load_audio()).squeeze(0) - if self.volume_norm: - context_audio = normalize_volume_torch(context_audio) - - context_audio = torch.nn.functional.pad( - context_audio, - (0, self.codec_model_downsample_factor - (context_audio.shape[0] % self.codec_model_downsample_factor)), - value=0 - ).unsqueeze(0) - context_audios_len = context_audio.shape[1] - context_audios.append(context_audio) - context_audios_lens.append(context_audios_len) - - # load context text - if cut.supervisions[0].speaker == "user": - if self.use_text_conditioning_tokenizer: - context_text = cut.supervisions[0].text - context_tokenizer = self.text_conditioning_tokenizer if self.text_conditioning_tokenizer else self.text_tokenizer - # check if the text is not empty - if context_text.replace(" ", ""): - context_text = self.text_conditioning_tokenizer(context_text)['input_ids'] - has_text_context_list.append(True) - else: - context_text = self.text_conditioning_tokenizer("[NO TEXT CONTEXT]")['input_ids'] - has_text_context_list.append(False) - - if self.pad_context_text_to_max_duration: - _required_len = int(self.context_duration_max * self.sample_rate / self.codec_model_downsample_factor) + 2 # +2 for BOS and EOS - if len(context_text) < _required_len: - _pad_id = self.text_conditioning_tokenizer.pad_token_id - context_text += [_pad_id] * (_required_len - len(context_text)) - else: - context_text = context_text[:_required_len] - - context_text = torch.tensor(context_text, dtype=torch.int32) - context_text_len = context_text.shape[0] - context_text_tokens.append(context_text) - context_text_tokens_lens.append(context_text_len) + def get_num_audio_samples_to_slice(self, duration, sample_rate): + num_codec_frames = int(duration * sample_rate / self.codec_model_downsample_factor) + num_audio_samples = num_codec_frames * self.codec_model_downsample_factor + return num_audio_samples + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: + # layze initialize tokenizers. The first time any specific worker + # process calls this function, on its copy of the dataset, the + # tokenizers are created for that worker. All subsequent calls + # to this function will reuse the tokenizers. This equivilent to + # the `worker_init_fn` in MagpieTTSModel. + if self.text_tokenizer is None: + # First time this worker is accessing the dataset, initialize the + # tokenizers. If called by the main process (num_workers=0), worker_info will be None. 
+ worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + logging.info(f"Worker {worker_id} initializing tokenizers...") + self.text_tokenizer, self.text_conditioning_tokenizer = setup_tokenizers( + all_tokenizers_config=self.tokenizer_config, + use_text_conditioning_tokenizer=self.use_text_conditioning_tokenizer, + mode=self.dataset_type, + ) + self.bos_id = len(self.text_tokenizer.tokens) + self.eos_id = self.bos_id + 1 + self.pad_id = self.text_tokenizer.pad + + # define list to store batched information + dataset_name_list = [] + audio_list = [] + audio_len_list = [] + audio_list_16khz = [] + audio_len_list_16khz = [] + token_list = [] + token_len_list = [] + prior_list = [] + audio_codes_list = [] + audio_codes_len_list = [] + context_audio_list = [] + context_audio_len_list = [] + context_audio_codes_list = [] + context_audio_codes_len_list = [] + context_text_tokens_list = [] + context_text_tokens_len_list = [] + context_has_text_context_list = [] + reward_list = [] + raw_text_list = [] + target_codes_field = f"codes_{self.codec_model_name}" + context_codes_field = f"context_codes_{self.codec_model_name}" + for cut in cuts: + speaker = cut.supervisions[0].speaker + if not check_speaker_format(speaker): + raise ValueError(f"Invalid format in cut.supervisions[0].speaker: {speaker}") + dataset_name = speaker.strip().split()[2].split(":")[-1] + dataset_name_list.append(dataset_name) + + # target audio or target codes + if self.load_cached_codes_if_available and cut.has_custom(target_codes_field): + audio_codes = torch.from_numpy(cut.load_custom(target_codes_field)).long() # (8, T) + audio_bos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_bos_id, dtype=audio_codes.dtype) + audio_eos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_eos_id, dtype=audio_codes.dtype) + audio_codes = torch.cat([audio_bos_tensor, audio_codes, audio_eos_tensor], dim=1) + audio_codes_len = audio_codes.shape[1] + spec_len = audio_codes.shape[1] + 1 # +1 for EOS + audio_codes_list.append( + audio_codes.T + ) # transpose to (T, 8) in order to use collate_matrices to process batch. + audio_codes_len_list.append(audio_codes_len) + else: + # Only load audio if codes are not available + audio_array = cut.recording.resample(self.sample_rate).load_audio().squeeze(0) + if self.volume_norm: + audio_array = normalize_volume(audio_array) + audio = torch.from_numpy(audio_array) + # Pad audio to be multiple of downsample factor + audio = torch.nn.functional.pad( + audio, + (0, self.codec_model_downsample_factor - (audio.shape[0] % self.codec_model_downsample_factor)), + value=0, + ) + audio_len = audio.shape[0] + spec_len = int(audio_len / self.codec_model_downsample_factor) + 1 # +1 for EOS + audio_list.append(audio) + audio_len_list.append(audio_len) + + # context audio or context codes + if self.load_cached_codes_if_available and cut.has_custom(context_codes_field): + # TODO @xueyang: dev branch applied Tensor.long(), i.e. torch.int64 which is not necessary. 
+ # load audios and text + context_audio_codes = torch.from_numpy(cut.load_custom(context_codes_field)).long() # (8, T) + # Sample random duration between self.context_duration_min and self.context_duration_max + _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) + _num_frames_to_slice = int( + _context_duration_to_slice * self.sample_rate / self.codec_model_downsample_factor + ) + if _num_frames_to_slice < context_audio_codes.shape[1]: + start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) + context_audio_codes = context_audio_codes[:, start_idx : start_idx + _num_frames_to_slice] + else: + # Repeat the audio if it is shorter than the desired duration + _num_repeats = int(np.ceil(_num_frames_to_slice / context_audio_codes.shape[1])) + # context_audio_codes is a tensor of shape (num_codebooks, T) + context_audio_codes_repeated = context_audio_codes.repeat(1, _num_repeats) + context_audio_codes = context_audio_codes_repeated[:, :_num_frames_to_slice] + + context_bos_tensor = torch.full( + (context_audio_codes.shape[0], 1), self.context_audio_bos_id, dtype=context_audio_codes.dtype + ) + context_eos_tensor = torch.full( + (context_audio_codes.shape[0], 1), self.context_audio_eos_id, dtype=context_audio_codes.dtype + ) + context_audio_codes = torch.cat([context_bos_tensor, context_audio_codes, context_eos_tensor], dim=1) + context_audio_codes_len = context_audio_codes.shape[1] + context_audio_codes_list.append( + context_audio_codes.T + ) # transpose to (T, 8) in order to use collate_matrices to process batch. + context_audio_codes_len_list.append(context_audio_codes_len) + elif cut.has_custom("context_recording"): + # Only load audio if codes are not available + context_audio_array = cut.context_recording.resample(self.sample_rate).load_audio().squeeze(0) + if self.volume_norm: + context_audio_array = normalize_volume(context_audio_array) + _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) + _num_samples_to_slice = self.get_num_audio_samples_to_slice( + _context_duration_to_slice, self.sample_rate + ) + if _num_samples_to_slice < len(context_audio_array): + start_idx = random.randint(0, len(context_audio_array) - _num_samples_to_slice) + context_audio_array = context_audio_array[start_idx : start_idx + _num_samples_to_slice] + else: + # Repeat the audio if it is shorter than the desired duration + _num_repeats = int(np.ceil(_num_samples_to_slice / len(context_audio_array))) + context_audio_array = np.tile(context_audio_array, _num_repeats) + context_audio_array = context_audio_array[:_num_samples_to_slice] + context_audio = torch.from_numpy(context_audio_array) + context_audio_len = context_audio.shape[0] + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio_len) else: - raise Exception("First speaker should be user") - - if cut.supervisions[1].speaker == "agent": - target_text = cut.supervisions[1].text - # check if the text is not empty - if target_text.replace(" ", ""): - tokenizer_name = "english_phoneme" # Default to english phoneme tokenizer - if getattr(cut, "tokenizer_names", None): - # Pick a random tokenizer from the list of tokenizers - tokenizer_name = random.choice(cut.tokenizer_names) - - target_text = self.text_tokenizer.encode(text=target_text, tokenizer_name=tokenizer_name) - target_text = target_text + [self.eos_id] + # We always want to have context_audio_codes if available for multi-encoder model. 
These are ignored + # for singlencoder model. + # If context audio is not available, just use a dummy context_audio_codes + # (Will be used in text context scenario) + # TODO @xueyang: verified that this block should cover below 3 conditions which were handled well. + # 1. load_cached_codes_if_available and ["context_audio_codes_path", "context_audio_filepath"] not in data.manifest_entry; + # assign to example["context_audio_codes"] and example["context_audio_codes_len"] + # 2. load_cached_codes_if_available is not True and "context_audio_codes_path" in data.manifest_entry; + # assign to example["context_audio"] and example["context_audio_len"] + # 3. load_cached_codes_if_available is not True and ["context_audio_codes_path", "context_audio_filepath"] not in data.manifest_entry; + # assign to example["context_audio"] and example["context_audio_len"] + if self.load_cached_codes_if_available: + context_bos_tensor = torch.full( + (self.num_audio_codebooks, 1), self.context_audio_bos_id, dtype=torch.int32 + ) + context_eos_tensor = torch.full( + (self.num_audio_codebooks, 1), self.context_audio_eos_id, dtype=torch.int32 + ) + context_audio_codes = torch.cat([context_bos_tensor, context_eos_tensor], dim=1) + context_audio_codes_len = context_audio_codes.shape[1] + context_audio_codes_list.append( + context_audio_codes.T + ) # transpose to (T, 8) in order to use collate_matrices to process batch. + context_audio_codes_len_list.append(context_audio_codes_len) else: - target_text = [self.eos_id] + # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes + context_audio = torch.zeros(self.codec_model_downsample_factor, dtype=torch.float32) + context_audio_len = context_audio.shape[0] + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio_len) - target_text = torch.tensor(target_text, dtype=torch.int32) - target_text_len = target_text.shape[0] - target_text_tokens.append(target_text) - target_text_tokens_lens.append(target_text_len) + if self.load_16khz_audio: + if cut.has_custom("context_recording"): + # use context audio for SV model + audio_array_16khz = cut.context_recording.resample(16_000).load_audio().squeeze(0) + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + else: + # Otherwise, load the target audio for SV model. 
+ audio_array_16khz = cut.recording.resample(16_000).load_audio().squeeze(0) + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) + _num_samples_to_slice = int(_context_duration_to_slice * 16_000) + if _num_samples_to_slice < len(audio_array_16khz): + start_idx = random.randint(0, len(audio_array_16khz) - _num_samples_to_slice) + audio_array_16khz = audio_array_16khz[start_idx : start_idx + _num_samples_to_slice] + audio_16khz = torch.from_numpy(audio_array_16khz) + audio_len_16khz = audio_16khz.shape[0] + audio_list_16khz.append(audio_16khz) + audio_len_list_16khz.append(audio_len_16khz) + + if self.use_text_conditioning_tokenizer: + if cut.supervisions[0].has_custom("context_text"): + context_text_tokens = self.text_conditioning_tokenizer(cut.supervisions[0].context_text)[ + 'input_ids' + ] + has_text_context = True + else: + context_text_tokens = self.text_conditioning_tokenizer("[NO TEXT CONTEXT]")['input_ids'] + has_text_context = False + if self.pad_context_text_to_max_duration: + _required_len = ( + int(self.context_duration_max * self.sample_rate / self.codec_model_downsample_factor) + 2 + ) # +2 for BOS and EOS + if len(context_text_tokens) < _required_len: + _pad_id = self.text_conditioning_tokenizer.pad_token_id + context_text_tokens += [_pad_id] * (_required_len - len(context_text_tokens)) + else: + # TODO @xueyang: It seems counter intuition if trimming the text context tokens to the required + # context length. For example, the context_tokens after trimming may correspond to the partial + # context_text like "Speaker and Emotion: | Language:en Dataset(trimmed :Riva Speaker:Rodney_DROP |)" + context_text_tokens = context_text_tokens[:_required_len] + context_text_tokens = torch.tensor(context_text_tokens, dtype=torch.int32) + context_text_tokens_len = context_text_tokens.shape[0] + context_text_tokens_list.append(context_text_tokens) + context_text_tokens_len_list.append(context_text_tokens_len) + context_has_text_context_list.append(has_text_context) + + # tokenize transcript + # TODO @xueyang: temporally apply raw text. will check to change if normalized text is available. 
+ raw_text = cut.supervisions[0].text + raw_text_list.append(raw_text) + if cut.has_custom("tokenizer_names"): + # Pick a random tokenizer from the list of tokenizers + tokenizer_name = random.choice(cut.tokenizer_names) else: - raise Exception("Second speaker should be agent") + tokenizer_name = "english_phoneme" # Default to english phoneme tokenizer + tokens = self.text_tokenizer.encode(text=raw_text, tokenizer_name=tokenizer_name) + tokens = tokens + [self.eos_id] # Not adding BOS id + tokens = torch.tensor(tokens, dtype=torch.int32) + text_len = tokens.shape[0] + token_list.append(tokens) + token_len_list.append(text_len) if self.include_align_prior: # align_prior = self.beta_binomial_interpolator(spec_len, text_len) - align_prior = beta_binomial_prior_distribution(phoneme_count=target_text_len, mel_count=num_frames, scaling_factor=self.prior_scaling_factor) + align_prior = beta_binomial_prior_distribution( + phoneme_count=text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor + ) align_prior = torch.tensor(align_prior, dtype=torch.float32) - align_priors.append(align_prior) - - if self.load_16khz_audio: - target_audio_16khz = librosa.resample(answer_audio.squeeze(0).numpy(), orig_sr=self.sample_rate, target_sr=16000) - target_audio_16khz = torch.FloatTensor(target_audio_16khz).unsqueeze(0) - target_audio_16khz_len = target_audio_16khz.shape[1] - target_audios_16khz.append(target_audio_16khz) - target_audios_16khz_lens.append(target_audio_16khz_len) - - # collate target/agent audios - target_audios = collate_vectors( - [a.squeeze(0) for a in target_audios], max_length=max(target_audios_lens), padding_value=0.0 - ).float() - target_audios_lens = torch.IntTensor(target_audios_lens) - num_codec_frames = torch.IntTensor(num_codec_frames) - - # collate context/user audios - context_audios = collate_vectors( - [a.squeeze(0) for a in context_audios], max_length=max(context_audios_lens), padding_value=0.0 - ).float() - context_audios_lens = torch.IntTensor(context_audios_lens) - - # collate context/user text - if self.use_text_conditioning_tokenizer: - context_text_tokens = collate_vectors(context_text_tokens, max_length=max(context_text_tokens_lens), padding_value=self.text_tokenizer.pad) - context_text_tokens_lens = torch.IntTensor(context_text_tokens_lens) + prior_list.append(align_prior) - # collate target/agent text - target_text_tokens = collate_vectors(target_text_tokens, max_length=max(target_text_tokens_lens), padding_value=self.text_tokenizer.pad) - target_text_tokens_lens = torch.IntTensor(target_text_tokens_lens) - - # collate align prior - if self.include_align_prior: - spec_max_len = max([prior.shape[0] for prior in align_priors]) - text_max_len = max([prior.shape[1] for prior in align_priors]) - align_priors = stack_tensors(align_priors, max_lens=[text_max_len, spec_max_len],) - - # collate 16khz target/agent audio - if self.load_16khz_audio: - target_audios_16khz = collate_vectors( - [a.squeeze(0) for a in target_audios_16khz], max_length=max(target_audios_16khz_lens), padding_value=0.0 - ).float() - target_audios_16khz_lens = torch.IntTensor(target_audios_16khz_lens) + if cut.supervisions[0].has_custom("reward"): + reward = cut.supervisions[0].reward + reward_list.append(reward) + # collate vectors and matrices here. 
batch_dict = { - # "dataset_names": dataset_names, - # "audio_filepaths": audio_filepath_list, - "sample_ids": list(cuts.ids), - "text": target_text_tokens, - "text_lens": target_text_tokens_lens, - 'audio': target_audios, - 'audio_lens': target_audios_lens, - # 'audio_codes': batch_audio_codes - # 'audio_codes_lens': batch_audio_codes_len - 'context_audio': context_audios, - 'context_audio_lens': context_audios_lens, - # 'context_audio_codes': batch_context_audio_codes - # 'context_audio_codes_lens': batch_context_audio_codes_len + "dataset_names": dataset_name_list, + "raw_texts": raw_text_list, + "text": collate_vectors(token_list, padding_value=self.pad_id), # (B, max_len) + "text_lens": torch.IntTensor(token_len_list), } - if self.include_align_prior: - batch_dict["align_prior_matrix"] = align_priors - - if self.load_16khz_audio: - batch_dict['audio_16khz'] = target_audios_16khz - batch_dict['audio_lens_16khz'] = target_audios_16khz_lens + # audio for SV. + if len(audio_list_16khz) > 0: + batch_dict["audio_16khz"] = collate_vectors(audio_list_16khz, padding_value=0.0) + batch_dict["audio_lens_16khz"] = torch.IntTensor(audio_len_list_16khz) + + # target audio and codes + if len(audio_list) > 0: + batch_dict["audio"] = collate_vectors(audio_list, padding_value=0.0) + batch_dict["audio_lens"] = torch.IntTensor(audio_len_list) + if len(audio_codes_list) > 0: + # transpose back to (B, 8, T) from (B, T, 8). + batch_dict["audio_codes"] = collate_matrices(audio_codes_list, padding_value=0).transpose(1, 2) + batch_dict["audio_codes_lens"] = torch.IntTensor(audio_codes_len_list) + + # context audio and codes + if len(context_audio_list) > 0: + batch_dict["context_audio"] = collate_vectors(context_audio_list, padding_value=0.0) + batch_dict["context_audio_lens"] = torch.IntTensor(context_audio_len_list) + if len(context_audio_codes_list) > 0: + # transpose back to (B, 8, T) from (B, T, 8). 
+ batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose(1, 2) + batch_dict["context_audio_codes_lens"] = torch.IntTensor(context_audio_codes_len_list) if self.use_text_conditioning_tokenizer: - batch_dict['context_text_tokens'] = context_text_tokens - batch_dict['context_text_len'] = context_text_tokens_lens - batch_dict['has_text_context'] = torch.BoolTensor(has_text_context_list) + batch_dict['context_text_tokens'] = collate_vectors( + tensors=context_text_tokens_list, padding_value=self.text_conditioning_tokenizer.pad_token_id + ) + batch_dict['context_text_tokens_lens'] = torch.IntTensor(context_text_tokens_len_list) + batch_dict['has_text_context'] = torch.BoolTensor(context_has_text_context_list) - return batch_dict + if self.include_align_prior: + spec_max_len = max([prior.shape[0] for prior in prior_list]) + text_max_len = max([prior.shape[1] for prior in prior_list]) + batch_dict["align_prior_matrix"] = stack_tensors(prior_list, max_lens=[text_max_len, spec_max_len]) + + if len(reward_list) > 0: + batch_dict['rewards'] = torch.FloatTensor(reward_list) + # Assert only ONE of context_audio or context_audio_codes in the batch + assert ('audio' in batch_dict) ^ ('audio_codes' in batch_dict) - def collate_fn(self, batch: List[dict]): - return batch + # Assert only ONE of context_audio or context_audio_codes in the batch + if 'context_audio' in batch_dict: + assert 'context_audio_codes' not in batch_dict + if 'context_audio_codes' in batch_dict: + assert 'context_audio' not in batch_dict + + return batch_dict diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index b82c07a64e68..d2beff1ae0ca 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -28,11 +28,11 @@ from omegaconf import DictConfig, open_dict from torch import nn from torch.utils.data import get_worker_info -from transformers import AutoTokenizer, T5Tokenizer import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate -from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 @@ -54,36 +54,6 @@ HAVE_WANDB = False -def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mode='train'): - # Being used in both model and worker_init_fn, so it is defined here - # Returns two tokenizers: one for TTS transcript and one for conditioning text (if needed) - tokenizers = [] - tokenizer_names = [] - for tokenizer_name in all_tokenizers_config: - tokenizer_config = all_tokenizers_config[tokenizer_name] - if tokenizer_config._target_ == 'AutoTokenizer': - tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.pretrained_model) - else: - text_tokenizer_kwargs = {} - if "g2p" in tokenizer_config: - text_tokenizer_kwargs["g2p"] = instantiate(tokenizer_config.g2p) - tokenizer = instantiate(tokenizer_config, **text_tokenizer_kwargs) - if mode == 'test' and hasattr(tokenizer, "set_phone_prob"): - tokenizer.set_phone_prob(1.0) - tokenizers.append(tokenizer) - tokenizer_names.append(tokenizer_name) - - aggregated_tokenizer = 
AggregatedTTSTokenizer(tokenizers, tokenizer_names) # TTS Transcript tokenizer - text_conditioning_tokenizer = None - - if use_text_conditioning_tokenizer: - # TODO: make this configurable - # Conditioning text tokenizer - text_conditioning_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") - - return aggregated_tokenizer, text_conditioning_tokenizer - - def worker_init_fn(worker_id): # For mp.set_start_method("spawn", force=True) # The dataset class should be picklable, so we initialize non-picklable objects here @@ -135,9 +105,14 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): del cfg['text_tokenizer'] self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False) - tokenizer, text_conditioning_tokenizer = self._setup_tokenizers(cfg) - self.tokenizer = tokenizer - self.text_conditioning_tokenizer = text_conditioning_tokenizer + # TODO @xueyang: both tokenizers are only used to get some token ids. We + # should kill them to save a small mount of mem resources since dataloader will initialize them + # again after the worker processes are spawned. + self.tokenizer, self.text_conditioning_tokenizer = setup_tokenizers( + all_tokenizers_config=cfg.text_tokenizers, + use_text_conditioning_tokenizer=self.use_text_conditioning_encoder, + mode='train', + ) num_tokens_tokenizer = len(self.tokenizer.tokens) num_tokens = num_tokens_tokenizer + 2 # +2 for BOS and EOS @@ -148,6 +123,10 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.audio_eos_id = cfg.num_audio_tokens_per_codebook - 1 self.context_audio_bos_id = cfg.num_audio_tokens_per_codebook - 2 # For backward compatibility self.context_audio_eos_id = cfg.num_audio_tokens_per_codebook - 1 # For backward compatibility + + if self.use_text_conditioning_encoder: + self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) + self.model_type = cfg.get('model_type', 'single_encoder_sv_tts') if self.model_type == 'decoder_context_tts': @@ -239,9 +218,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): else: raise ValueError(f"Unsupported model type {self.model_type}") - if self.use_text_conditioning_encoder: - self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) - self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0) alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0) @@ -265,12 +241,6 @@ def load_state_dict(self, state_dict, strict=True): # Override to load all the keys except _speaker_verification_model and _codec_model super().load_state_dict(state_dict, strict=False) - def _setup_tokenizers(self, cfg, mode='test'): - tokenizer, text_conditioning_tokenizer = setup_tokenizers( - cfg.text_tokenizers, cfg.use_text_conditioning_encoder, mode=mode - ) - return tokenizer, text_conditioning_tokenizer - def audio_to_codes(self, audio, audio_len, audio_type='target'): # audio: (B, T) # audio_len: (B,) @@ -432,7 +402,6 @@ def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): def sample_codes_from_local_transformer(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0): # dec_output: (B, E) - # import ipdb; ipdb.set_trace() self.local_transformer.reset_cache(use_cache=True) dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 
128) @@ -1509,9 +1478,9 @@ def on_validation_epoch_end(self): self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() # free memory - def get_dataset(self, cfg, dataset_type): + def get_dataset(self, dataset_cfg, dataset_type): dataset = instantiate( - cfg.dataset, + dataset_cfg.dataset, bos_id=self.bos_id, eos_id=self.eos_id, audio_bos_id=self.audio_bos_id, @@ -1534,40 +1503,87 @@ def get_dataset(self, cfg, dataset_type): ) # This will be used in worker_init_fn for instantiating tokenizer return dataset - def setup_training_data(self, cfg): - dataset = self.get_dataset(cfg, dataset_type='train') - sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) - persistent_workers = True - if cfg.dataloader_params.num_workers == 0: - persistent_workers = False - # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) - dataset.text_tokenizer, dataset.text_conditioning_tokenizer = self._setup_tokenizers(self.cfg) - self._train_dl = torch.utils.data.DataLoader( - dataset, - collate_fn=dataset.collate_fn, - sampler=sampler, - **cfg.dataloader_params, - worker_init_fn=worker_init_fn, - persistent_workers=persistent_workers, + def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader: + # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also + # cfg is a classifier-free guidance. + dataset = MagpieTTSLhotseDataset( + sample_rate=self.cfg.sample_rate, + volume_norm=dataset_cfg.volume_norm, + codec_model_downsample_factor=self.cfg.codec_model_downsample_factor, + codec_model_name=self.cfg.codec_model_name, + audio_bos_id=self.audio_bos_id, + audio_eos_id=self.audio_eos_id, + context_audio_bos_id=self.context_audio_bos_id, + context_audio_eos_id=self.context_audio_eos_id, + num_audio_codebooks=self.cfg.num_audio_codebooks, + prior_scaling_factor=self.cfg.prior_scaling_factor, + load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, + dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) + load_16khz_audio=(self.model_type == 'single_encoder_sv_tts'), + pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, + context_duration_min=self.cfg.context_duration_min, + context_duration_max=self.cfg.context_duration_max, + use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, + tokenizer_config=self.cfg.text_tokenizers, ) - - def _setup_test_dataloader(self, cfg) -> torch.utils.data.DataLoader: - dataset = self.get_dataset(cfg, dataset_type='test') - persistent_workers = True - if cfg.dataloader_params.num_workers == 0: - persistent_workers = False - # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) - dataset.text_tokenizer, dataset.text_conditioning_tokenizer = self._setup_tokenizers(self.cfg, mode='test') - - data_loader = torch.utils.data.DataLoader( - dataset, - collate_fn=dataset.collate_fn, - **cfg.dataloader_params, - worker_init_fn=worker_init_fn, - persistent_workers=persistent_workers, + data_loader = get_lhotse_dataloader_from_config( + config=dataset_cfg.dataset, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=dataset, ) return data_loader + def setup_training_data(self, dataset_cfg): + if dataset_cfg.get("use_lhotse", False): + # TODO @xueyang: better to distinguish cfg. 
self.cfg is the model cfg, while cfg here is train_ds cfg. Also + # cfg is a classifier-free guidance. + self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train') + else: + dataset = self.get_dataset(dataset_cfg, dataset_type='train') + sampler = dataset.get_sampler(dataset_cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) + persistent_workers = True + if dataset_cfg.dataloader_params.num_workers == 0: + persistent_workers = False + # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) + dataset.text_tokenizer, dataset.text_conditioning_tokenizer = setup_tokenizers( + all_tokenizers_config=self.cfg.text_tokenizers, + use_text_conditioning_tokenizer=self.use_text_conditioning_encoder, + mode='train', + ) + self._train_dl = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + sampler=sampler, + **dataset_cfg.dataloader_params, + worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers, + ) + + def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: + if dataset_cfg.get("use_lhotse", False): + data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test') + else: + dataset = self.get_dataset(dataset_cfg, dataset_type='test') + persistent_workers = True + if dataset_cfg.dataloader_params.num_workers == 0: + persistent_workers = False + # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) + dataset.text_tokenizer, dataset.text_conditioning_tokenizer = setup_tokenizers( + all_tokenizers_config=self.cfg.text_tokenizers, + use_text_conditioning_tokenizer=self.use_kv_cache_for_inference, + mode='test' + ) + + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + **dataset_cfg.dataloader_params, + worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers, + ) + return data_loader + def setup_validation_data(self, cfg): self._validation_dl = self._setup_test_dataloader(cfg) diff --git a/scripts/magpietts/README_lhotse.md b/scripts/magpietts/README_lhotse.md new file mode 100644 index 000000000000..f56eb1cdd4dd --- /dev/null +++ b/scripts/magpietts/README_lhotse.md @@ -0,0 +1,174 @@ +This guidance describes general steps on converting NeMo datasets to Lhotse Shar datasets for training +and validating Magpie-TTS. + +## Creating New Lhotse Shar Data +Step 1: reformatting `speaker` field in the NeMo manifest to pass the format check as the function defined, +```python +def check_speaker_format(item: str): + # enforce the format as example like "| Language:en Dataset:HiFiTTS Speaker:9136_other |". + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + return bool(re.match(pattern, item)) +``` + +Step 2: create Lhotse Shar dataset by running, +```bash +# codec +CODEC_MODEL_NAME="21fpsCausalDecoder" +CODEC_MODEL_PATH="/codecs/21fps_causal_codecmodel.nemo" +CODEC_FRAME_RATE=21.5 +SAMPLE_RATE=22050 +PAD_MULTIPLE=1024 + +# trainer +DEVICES=-1 +NUM_NODES=1 +BATCH_SIZE=48 +NUM_WORKERS=10 +SHARD_SIZE=4096 + +# code +CODE_DIR="/workspace/NeMo" + +# NeMo manifest +MANIFEST_PATH="/manifests/hifitts_train_withContextAudioMinDur3MinSSIM0.6.json" + +# audio base dir +AUDIO_BASE_DIR="/audio/hi_fi_tts_v0" + +# save dir for Shar +SAVE_DIR="/data_shar_train" + +echo "*******STARTING********" +cd ${CODE_DIR} +export PYTHONPATH="${CODE_DIR}:${PYTHONPATH}" +echo "Starting Codec Extraction..." 
+python scripts/magpietts/convert_nemo_to_lhotse_shar.py + --manifest ${MANIFEST_PATH} + --audio_base_dir ${AUDIO_BASE_DIR} + --save_dir ${SAVE_DIR} + --codec_model_name ${CODEC_MODEL_NAME} + --codec_model_path ${CODEC_MODEL_PATH} + --codec_frame_rate ${CODEC_FRAME_RATE} + --sample_rate ${SAMPLE_RATE} + --pad_multiple ${PAD_MULTIPLE} + --devices ${DEVICES} + --num_nodes ${NUM_NODES} + --batch_size ${BATCH_SIZE} + --num_workers ${NUM_WORKERS} + --shard_size ${SHARD_SIZE} +``` + +Step 3: check the files by looking at the folder, +```shell +Examples of shard files: +$ tree data_shar_train/ + - cuts.000000.jsonl.gz # Lhotse manifest. + - codes_21fpsCausalDecoder.000000.tar # target codec codes. + - recording.000000.tar # target audio waveforms. + - context_codes_21fpsCausalDecoder.000000.tar # context audio codec codes. + - context_recording.000000.tar # context audio waveforms. +``` + +When peek one of the item from `cuts.000000.jsonl.gz`, you should expect the structure as, +```python +MonoCut( + id='cut-audio-11614_other-12352-prideofjennico_01_castle_0000', + start=0, + duration=6.16, + channel=0, + supervisions=[ + SupervisionSegment( + id='sup-audio-11614_other-12352-prideofjennico_01_castle_0000', + recording_id='audio-11614_other-12352-prideofjennico_01_castle_0000', + start=0.0, + duration=6.16, + channel=0, + text='late in the year seventeen seventy one as the wind rattles the casements with impotent clutch', + language='en', + speaker='| Language:en Dataset:HiFiTTS Speaker:11614 |', + gender=None, + custom={}, + alignment=None + ) + ], + features=None, + recording=Recording( + id='audio-11614_other-12352-prideofjennico_01_castle_0000', + sources=[ + AudioSource( + type='memory', + channels=[0], + source='' + ) + ], + sampling_rate=44100, + num_samples=271656, + duration=6.16, + channel_ids=[0], + transforms=None + ), + custom={ + 'codes_21fpsCausalDecoder': TemporalArray( + array=Array( + storage_type='memory_npy', + storage_path='', + storage_key='', + shape=[8, 133] + ), + temporal_dim=1, + frame_shift=0.046511627906976744, + start=0 + ), + 'context_codes_21fpsCausalDecoder': TemporalArray( + array=Array( + storage_type='memory_npy', + storage_path='', + storage_key='', + shape=[8, 138] + ), + temporal_dim=1, + frame_shift=0.046511627906976744, + start=0 + ), + 'context_recording': Recording( + id='audio-11614_other-12220-barontrump_31_lockwood_0096', + sources=[ + AudioSource( + type='memory', + channels=[0], + source='' + ) + ], + sampling_rate=44100, + num_samples=282240, + duration=6.4, + channel_ids=[0], + transforms=None + ), + 'shard_origin': PosixPath('cuts.000000.jsonl.gz'), + 'shar_epoch': 0 + } +) +``` + +## Appending New Codec Codes to Existing Lhotse Manifest +TBD. In genenral, the solution is to load existing cuts of shards, attach the new codec codes to the +MonoCut's `custom` field, and write cuts and new codec codes into shard files. This should uses the +same index of shards. + +## (Internal Only) Sharing Slurm Job Sub Scripts to Create Lhotse Shar +All scripts are stored in +https://gitlab-master.nvidia.com/xueyang/nemo-tts-artifacts-registry/-/tree/main/data_prep_lhotse . + +```shell +$ tree . +. 
+├── extract_audioCodec_21fpsCausalDecoder_eos.sub +├── hifitts2_extract_audioCodec_21fpsCausalDecoder_eos.sub +├── README_lhotse.md +├── reserve_interactive_node.sh +└── submit_jobs_for_all_datasets.sh + +$ bash submit_jobs_for_all_datasets.sh +``` + diff --git a/scripts/magpietts/convert_nemo_to_lhotse_shar.py b/scripts/magpietts/convert_nemo_to_lhotse_shar.py new file mode 100644 index 000000000000..ec3e4b98ea83 --- /dev/null +++ b/scripts/magpietts/convert_nemo_to_lhotse_shar.py @@ -0,0 +1,464 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example entry in `/home/xueyang/workspace/pretrain/data_prep/hifitts2/manifests/train_manifest_withContextAudioMinDur3.json` +{ + "audio_filepath": "100/2315/100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies_0.flac", + "duration": 6.2, + "speaker": "| Language:en Dataset:HiFiTTS2 Speaker:100 |", + "text": "THE oceans are big and broad. I believe two thirds of the earth's surface is covered with water.", + "normalized_text": "THE oceans are big and broad. I believe two thirds of the earth's surface is covered with water.", + "text_source": "book", + "bandwidth": 13092, + "snr1": 41.27, + "snr2": 41.05, + "snr3": 32.58, + "snr4": 22.28, + "is_segmented": true, + "wer": 0, + "cer": 0, + "ins": 0, + "del": 0, + "sub": 0, + "speaker_count": 1, + "chapter_id": "01", + "context_speaker_similarity": 0.9059218168258667, + "context_audio_filepath": "100/2315/100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies_1.flac", + "context_audio_duration": 6.08 +} + +Goal: avoid to save inodes quota by tarring individual files for audio codecs, audio waveforms and/or speaker embeddings. + We can decide if remove audio waveform files later. + +Examples of shard files: +$ tree data-shar-train/ + - cuts.000000.jsonl.gz: add all and exclude unnecessary fields. + - codes_21fpsCausalDecoder.000000.tar + - recording.000000.tar: not used during training, but worth to tar them so save inodes quota and for future applications. + - context_codes_21fpsCausalDecoder.000000.tar + - context_recording.000000.tar + - context_spk_embed.000000.tar (optional): speaker embedding is not used during training/validation. + - spk_embed.000000.tar (optional): speaker embedding is not used during training/validation. 
+""" + +import argparse +import os +import re +from functools import partial +from pathlib import Path + +import lightning.pytorch as pl +import torch +from lhotse import MonoCut, Recording, SupervisionSegment +from lhotse.shar.writers.shar import AudioTarWriter, SharWriter +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import BasePredictionWriter +from lightning.pytorch.strategies import DDPStrategy +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.tts.models import AudioCodecModel + +MAYBE_EXTRA_METADATA_IN_MANIFEST = ["normalized_text", "speaker_count", "cer", "wer"] + + +def check_speaker_format(item: str): + # enforce the format as example like "| Language:en Dataset:HiFiTTS Speaker:9136_other |". + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + return bool(re.match(pattern, item)) + + +def get_recording_id(audio_base_dir: str, path: Path): + # the recording id is defined as the concatenation of relative audio filepath with a hyphen delimiter. + return path.relative_to(audio_base_dir).with_suffix("").as_posix().replace("/", "-") + + +class SharPredictionWriter(BasePredictionWriter): + def __init__( + self, + output_dir: str, + codec_model_name: str, + codec_frame_rate: float, + audio_base_dir: str, + fields: dict, + shard_size: int = 1000, + ): + super().__init__(write_interval="batch") + self.output_dir = output_dir + self.codec_model_name = codec_model_name + self.codec_frame_rate = codec_frame_rate + self.fields = fields + self.shard_size = shard_size + self.batch_counter = 0 + self.shar_writer = None + self.context_recording_writer = None + self.is_initialized = False + self.recording_id_fn = partial(get_recording_id, audio_base_dir) + + # Add a buffer with the shard size to accumulate cuts before writing to disk. + self.cuts_buffer = list() + self.buffer_size = shard_size + + def setup(self, trainer, pl_module, stage=None): + if not self.is_initialized: + # Only initialize the SharWriter and AudioTarWriter on rank 0 + if trainer.global_rank == 0: + os.makedirs(self.output_dir, exist_ok=True) + + # Initialize SharWriter + self.shar_writer = SharWriter( + output_dir=self.output_dir, fields=self.fields, shard_size=self.shard_size + ) + self.shar_writer.__enter__() + + # Initialize AudioTarWriter to store context recording as a workaround. + # TODO @xueyang: Without this, the process would be blocked because, + # When target duration is specified in MonoCut, the error will happen iff context duration < target duration + # mostly because the cut tries to trim the context_recording to the same duration as target. No errors + # were observed when context duration > target duration. 
Ref is https://nvidia.slack.com/archives/D068LR4TWUW/p1741817511544239 + self.context_recording_writer = AudioTarWriter( + pattern=os.path.join(self.output_dir, "context_recording.%06d.tar"), + shard_size=self.shard_size, + format="flac", + ) + self.context_recording_writer.__enter__() + + self.is_initialized = True + + def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx): + # prepare cuts from each rank + pred_cuts = self.convert_prediction_to_cuts(prediction) + # Gather predictions from all ranks + gathered_objects = [None] * trainer.world_size + torch.distributed.all_gather_object(gathered_objects, pred_cuts) + + # Only rank 0 writes to disk + if trainer.global_rank == 0: + for _pred_cuts, _context_recordings in gathered_objects: + if _pred_cuts is None or _context_recordings is None: + raise RuntimeError("Received None from all_gather_object") + + # Buffer the cuts + self.cuts_buffer.extend(list(zip(_pred_cuts, _context_recordings))) + + # Write when buffer is full + if len(self.cuts_buffer) >= self.buffer_size: + self._write_buffer() + + def _write_buffer(self): + """Write accumulated cuts from buffer""" + for cut, recording in self.cuts_buffer: + self.shar_writer.write(cut) + self.context_recording_writer.write( + key=cut.id, + value=recording.load_audio(), + sampling_rate=recording.sampling_rate, + manifest=recording, + ) + self.batch_counter += 1 + + # Clear the buffer + self.cuts_buffer = list() + + def convert_prediction_to_cuts(self, prediction): + # Extra useful metadata may exist in some manifests, so better to keep them for future usage. + meta_fields = { + meta_field: prediction[meta_field] + for meta_field in MAYBE_EXTRA_METADATA_IN_MANIFEST + if meta_field in prediction + } + + # This should convert predictions to Cut objects for Lhotse + cuts = list() + context_recordings = list() + + # batch process recordings and codes here. + # target recording + target_recordings = [ + Recording.from_file( + path=audio_filepath, + recording_id=self.recording_id_fn, + ) + for audio_filepath in prediction["target_audio_filepath"] + ] + context_recordings = [ + Recording.from_file( + path=audio_filepath, + recording_id=self.recording_id_fn, + ) + for audio_filepath in prediction["context_audio_filepath"] + ] + + # Create supervisions in batch + supervisions = [ + SupervisionSegment( + id=f"sup-{rec.id}", + recording_id=rec.id, + start=0.0, + duration=rec.duration, + channel=0, + text=text, + speaker=spk, + language=lang, + custom={key: val[idx] for key, val in meta_fields.items()} if meta_fields else None, + ) + for idx, (rec, text, spk, lang) in enumerate( + zip(target_recordings, prediction["text"], prediction["speaker"], prediction["language"]) + ) + ] + + # Create cuts in batch + # TODO @xueyang: should file a bug report to `attach_tensor` function. When `temporal_dim=-1`, the tensor is not + # attached correctly. For example, I found that `cuts[0].codes_21fpsCausalDecoder.load()` and + # `cuts[0].load_custom("codes_21fpsCausalDecoder")` returns different arrays, with different shapes. But the former + # returned expected (8,5) shape, while the latter returned (5,5). I also find that, after write shar files, and + # when i load codes using `CutSet.from_shar()` and no matter which load functions I used, they are all shape of (5,5) + # instead of (8,5). In any case, using default `temporal_dim` and `frame_shift` addressed this issue. 
+ cuts = [ + MonoCut( + id=f"cut-{rec.id}", + start=0.0, + duration=rec.duration, + recording=rec, + channel=0, + supervisions=[sup], + custom={"context_recording": context_rec}, + ).attach_tensor( + name=f"codes_{self.codec_model_name}", + data=target_code, + # temporal_dim=1, + # frame_shift=1 / self.codec_frame_rate + ).attach_tensor( + name=f"context_codes_{self.codec_model_name}", + data=context_code, + # temporal_dim=1, + # frame_shift=1 / self.codec_frame_rate + ) + for rec, sup, context_rec, target_code, context_code in zip( + target_recordings, + supervisions, + context_recordings, + prediction["target_codes"], + prediction["context_codes"], + ) + ] + + return cuts, context_recordings + + def teardown(self, trainer, pl_module, stage=None): + # Wait for rank 0 to finish writing + if trainer.world_size > 1: + torch.distributed.barrier() + + # Close the SharWriter and AudioTarWriter on rank 0 + if trainer.global_rank == 0: + # Write any remaining cuts in the buffer before closing + if self.cuts_buffer: + self._write_buffer() + + if self.context_recording_writer is not None: + self.context_recording_writer.close() + if self.shar_writer is not None: + self.shar_writer.close() + + +class AudioDataset(Dataset): + def __init__(self, manifest: str, audio_base_dir: str, sample_rate: int = 22050, pad_multiple: int = 1024): + self.audio_base_dir = audio_base_dir + self.sample_rate = sample_rate + self.pad_multiple = pad_multiple + self.items = read_manifest(manifest) + + def __len__(self): + return len(self.items) + + def get_wav_from_filepath(self, file_path: str): + features = AudioSegment.segment_from_file( + file_path, + target_sr=self.sample_rate, + n_segments=-1, + trim=False, + ) + audio = torch.tensor(features.samples) + audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) + audio_length = torch.tensor(audio.size(0)).long() + return audio, audio_length + + def __getitem__(self, idx): + item = self.items[idx] + if not check_speaker_format(item["speaker"]): + raise ValueError(f"Invalid speaker format at index {idx}: {item}") + target_audio_filepath = os.path.join(self.audio_base_dir, item["audio_filepath"]) + context_audio_filepath = os.path.join(self.audio_base_dir, item["context_audio_filepath"]) + target_audio, target_audio_length = self.get_wav_from_filepath(target_audio_filepath) + context_audio, context_audio_length = self.get_wav_from_filepath(context_audio_filepath) + output_dict = { + "target_audio_filepath": target_audio_filepath, + "target_audio": target_audio, + "target_audio_length": target_audio_length, + "target_audio_duration": item["duration"], + "context_audio_filepath": context_audio_filepath, + "context_audio": context_audio, + "context_audio_length": context_audio_length, + "context_audio_duration": item["context_audio_duration"], + "context_speaker_similarity": item["context_speaker_similarity"], + "speaker": item["speaker"], + "text": item["text"], + "language": item["speaker"].strip().split()[1].split(":")[-1], + } + # Extra useful metadata may exist in some manifests, so better to keep them for future usage. 
+ return self._copy_maybe_extra_metadata(item, output_dict) + + def collate_fn(self, batch): + max_target_audio_length = max(item["target_audio_length"].item() for item in batch) + target_audios_padded = [ + torch.nn.functional.pad(item["target_audio"], (0, max_target_audio_length - item["target_audio"].size(0))) + for item in batch + ] + max_context_audio_length = max(item["context_audio_length"].item() for item in batch) + context_audios_padded = [ + torch.nn.functional.pad( + item["context_audio"], (0, max_context_audio_length - item["context_audio"].size(0)) + ) + for item in batch + ] + output_dict = { + # target audio + "target_audio_filepath": [item["target_audio_filepath"] for item in batch], + "target_audios": torch.stack(target_audios_padded), + "target_audio_lengths": torch.stack([item["target_audio_length"] for item in batch]), + "target_audio_durations": [item["target_audio_duration"] for item in batch], + # context audio + "context_audio_filepath": [item["context_audio_filepath"] for item in batch], + "context_audios": torch.stack(context_audios_padded), + "context_audio_lengths": torch.stack([item["context_audio_length"] for item in batch]), + "context_audio_durations": [item["context_audio_duration"] for item in batch], + "context_speaker_similarity": [item["context_speaker_similarity"] for item in batch], + # metadata + "speaker": [item["speaker"] for item in batch], + "text": [item["text"] for item in batch], + "language": [item["language"] for item in batch], + } + # Extra useful metadata may exist in some manifests, so better to keep them for future usage. + for meta_field in MAYBE_EXTRA_METADATA_IN_MANIFEST: + if meta_field not in batch[0]: + continue + output_dict[meta_field] = [item[meta_field] for item in batch] + return output_dict + + @staticmethod + def _copy_maybe_extra_metadata(input_dict: dict, output_dict: dict): + # Extra useful metadata may exist in some manifests, so better to keep them for future usage. 
+ for meta_field in MAYBE_EXTRA_METADATA_IN_MANIFEST: + if meta_field in input_dict: + output_dict[meta_field] = input_dict[meta_field] + return output_dict + + +class CodecExtractor(pl.LightningModule): + def __init__(self, model_path: str): + super().__init__() + self.codec_model = AudioCodecModel.restore_from(restore_path=model_path, strict=False) + self.codec_model.eval() + + def forward(self, batch): + with torch.no_grad(): + target_codes, target_codes_lengths = self.codec_model.encode( + audio=batch["target_audios"], audio_len=batch["target_audio_lengths"] + ) + context_codes, context_codes_lengths = self.codec_model.encode( + audio=batch["context_audios"], audio_len=batch["context_audio_lengths"] + ) + return { + "target_codes": target_codes.cpu().type(torch.int16), + "target_codes_lengths": target_codes_lengths, + "context_codes": context_codes.cpu().type(torch.int16), + "context_codes_lengths": context_codes_lengths, + } + + def predict_step(self, batch, batch_idx): + codes_dict = self(batch) + target_codes = [ + codes[:, :codes_length] + for codes, codes_length in zip(codes_dict["target_codes"], codes_dict["target_codes_lengths"]) + ] + context_codes = [ + codes[:, :codes_length] + for codes, codes_length in zip(codes_dict["context_codes"], codes_dict["context_codes_lengths"]) + ] + batch.update( + { + "target_codes": target_codes, + "context_codes": context_codes, + } + ) + return batch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--manifest", type=str) + parser.add_argument("--audio_base_dir", type=str) + parser.add_argument("--save_dir", type=str) + parser.add_argument("--codec_model_name", type=str, default="21fpsCausalDecoder") + parser.add_argument("--codec_model_path", type=str) + parser.add_argument("--codec_frame_rate", type=float, default=21.5) + parser.add_argument("--pad_multiple", type=int, default=1024) + parser.add_argument("--sample_rate", type=int, default=22050) + parser.add_argument("--devices", type=int, default=-1) + parser.add_argument("--num_nodes", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=48) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--shard_size", type=int, default=4096) + args = parser.parse_args() + + codec_extractor = CodecExtractor(args.codec_model_path) + + dataset = AudioDataset( + manifest=args.manifest, + audio_base_dir=args.audio_base_dir, + sample_rate=args.sample_rate, + pad_multiple=args.pad_multiple, + ) + dataloader = DataLoader( + dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, collate_fn=dataset.collate_fn + ) + + # Note that context_recording would be stored using AudioTarWriter. + pred_writer = SharPredictionWriter( + output_dir=args.save_dir, + codec_model_name=args.codec_model_name, + audio_base_dir=args.audio_base_dir, + codec_frame_rate=args.codec_frame_rate, + fields={ + "recording": "flac", + f"codes_{args.codec_model_name}": "numpy", + f"context_codes_{args.codec_model_name}": "numpy", + }, + shard_size=args.shard_size, + ) + + trainer = Trainer( + devices=args.devices, + accelerator="gpu", + strategy=DDPStrategy(find_unused_parameters=False), + num_nodes=args.num_nodes, + logger=False, + ) + # add writer callback to all gather batched predictions and write into shards. 
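+    # Note: `return_predictions=False` in `trainer.predict(...)` below means Lightning does not accumulate
+    # batch outputs in memory; results are persisted to the shard files only through this writer callback.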
+ trainer.callbacks.append(pred_writer) + + trainer.predict(codec_extractor, dataloaders=dataloader, return_predictions=False) From 7f54c6dc3feb67d6a26cbe063eb5e932675c2e18 Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 11 Apr 2025 14:44:52 -0400 Subject: [PATCH 017/113] Merge in Jason's dev changes (#57) * add fix to infer script Signed-off-by: Jason * add no context option Signed-off-by: Jason * add nemo option to infer script Signed-off-by: Jason * add in latest bf16 changes from Edresson Signed-off-by: Jason * add comment Signed-off-by: Jason * enforce codec precision for now Signed-off-by: Jason * fix autocast bug Signed-off-by: Jason * another bug fix Signed-off-by: Jason * clean PR Signed-off-by: Jason * change hardcoded epsilon Signed-off-by: Jason * infer changes Signed-off-by: Jason * address review Signed-off-by: Jason --------- Signed-off-by: Jason --- .../tts/losses/audio_codec_loss.py | 4 +- nemo/collections/tts/models/audio_codec.py | 39 +++++--- nemo/collections/tts/models/magpietts.py | 60 ++++++------ .../tts/modules/audio_codec_modules.py | 87 ++++++++++------- .../tts/modules/transformer_2501.py | 2 +- scripts/magpietts/infer_and_evaluate.py | 97 ++++++++++++++----- 6 files changed, 186 insertions(+), 103 deletions(-) mode change 100644 => 100755 nemo/collections/tts/modules/audio_codec_modules.py diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index 6db3e30595c6..e87f4a959443 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -230,8 +230,8 @@ def output_types(self): @typecheck() def forward(self, audio_real, audio_gen, audio_len): spec_len = (audio_len // self.hop_length) + 1 - spec_real = self._compute_spectrogram(audio=audio_real, spec_len=spec_len) - spec_gen = self._compute_spectrogram(audio=audio_gen, spec_len=spec_len) + spec_real = self._compute_spectrogram(audio=audio_real.float(), spec_len=spec_len).to(audio_gen.dtype) + spec_gen = self._compute_spectrogram(audio=audio_gen.float(), spec_len=spec_len).to(audio_gen.dtype) loss = self.loss_fn(predicted=spec_gen, target=spec_real, target_len=spec_len) return loss diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 33b9a80125b7..134cf1e0f492 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -32,7 +32,7 @@ SISDRLoss, TimeDomainLoss, ) -from nemo.collections.tts.modules.audio_codec_modules import ResNetSpeakerEncoder +from nemo.collections.tts.modules.audio_codec_modules import ResNetSpeakerEncoder, default_precision from nemo.collections.tts.modules.common import GaussianDropout from nemo.collections.tts.parts.utils.callbacks import LoggingCallback from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers @@ -191,6 +191,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.lr_schedule_interval = None self.automatic_optimization = False + @property + def dtype(self): + return next(self.parameters()).dtype + def state_dict(self, destination=None, prefix='', keep_vars=False): if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} @@ -307,7 +311,9 @@ def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Te raise ValueError("Cannot quantize without quantizer") # vector quantizer is returning [C, B, T], where C is the number of codebooks - tokens = self.vector_quantizer.encode(inputs=encoded, 
input_len=encoded_len) + with default_precision(torch.float32): + # vector quantizer is returning [C, B, T], where C is the number of codebooks + tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) # use batch first for the output tokens = rearrange(tokens, 'C B T -> B C T') return tokens @@ -336,7 +342,9 @@ def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Te # vector quantizer is using [C, B, T], where C is the number of codebooks tokens = rearrange(tokens, 'B C T -> C B T') - dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) + with default_precision(torch.float32): + dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) + dequantized = dequantized.to(self.dtype) # make sure dequantized is in the right dtype return dequantized @typecheck( @@ -389,6 +397,7 @@ def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch. """ # Convert a discrete representation to a dequantized vector for each frame dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len) + dequantized = dequantized.to(self.dtype) # make sure that the dequantized is in the model dtype # Apply decoder to obtain time-domain audio for each frame audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len) @@ -459,15 +468,19 @@ def _process_batch(self, batch): encoded = self.encoder_noise(encoded) if self.vector_quantizer: - if self.vector_quantizer_has_commit_loss: - encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) - else: - encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len) - commit_loss = 0.0 + with default_precision(torch.float32): + if self.vector_quantizer_has_commit_loss: + encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + else: + encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + commit_loss = 0.0 + + encoded = encoded.to(encoded.dtype) # make sure encoded is converted to the right dtype else: commit_loss = 0.0 # [B, T] + encoded = encoded.to(self.dtype) # make sure vector quantizer output is in the model dtype audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len) return audio, audio_len, audio_gen, commit_loss @@ -508,7 +521,9 @@ def training_step(self, batch, batch_idx): generator_losses = [] - loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + # stft does not support bf16, so make it run in fp32 + loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) + if self.mel_loss_l1_scale: metrics["g_loss_mel_l1"] = loss_mel_l1 generator_losses.append(self.mel_loss_l1_scale * loss_mel_l1) @@ -517,7 +532,7 @@ def training_step(self, batch, batch_idx): generator_losses.append(self.mel_loss_l2_scale * loss_mel_l2) if self.stft_loss_scale: - loss_stft = self.stft_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) metrics["g_loss_stft"] = loss_stft generator_losses.append(self.stft_loss_scale * loss_stft) @@ -594,8 +609,8 @@ def on_train_epoch_end(self): def validation_step(self, batch, batch_idx): audio, audio_len, audio_gen, _ = self._process_batch(batch) - loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) - loss_stft = 
self.stft_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) + loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index d2beff1ae0ca..2ce6cff75c02 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -254,39 +254,39 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): raise ValueError(f"Received audio_type of {audio_type}. Must be `target` or `context`") self._codec_model.eval() - with torch.cuda.amp.autocast(enabled=False): - with torch.no_grad(): - codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) - # Add a timestep to begining and end of codes tensor - bos_tensor = torch.full( - (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device - ) - pad_tensor = torch.full( - (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device - ) # 0 is the padding token in the audio codebook - codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1) - # codes: (B, C, T') - # codes_len: (B,) - for idx in range(codes.size(0)): - codes[idx, :, codes_len[idx] + 1] = audio_eos_id - codes_len = codes_len + 2 - - return codes.long(), codes_len.long() + with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32): + codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) + # Add a timestep to begining and end of codes tensor + bos_tensor = torch.full( + (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device + ) + pad_tensor = torch.full( + (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device + ) # 0 is the padding token in the audio codebook + codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1) + # codes: (B, C, T') + # codes_len: (B,) + for idx in range(codes.size(0)): + codes[idx, :, codes_len[idx] + 1] = audio_eos_id + codes_len = codes_len + 2 + + return codes.long(), codes_len.long() def codes_to_audio(self, codes, codes_len): # codes: (B, C, T') # codes_len: (B,) self._codec_model.eval() - with torch.cuda.amp.autocast(enabled=False): - with torch.no_grad(): - # Replace eos and bos tokens with padding in codes tensor - codes[codes == self.audio_bos_id] = 0 # zero is the padding token in the audio codebook - codes[codes == self.audio_eos_id] = 0 - # self.additional_models['codec'] = self.additional_models['codec'].to(codes.device) - audio, audio_len = self._codec_model.decode(tokens=codes, tokens_len=codes_len) - # audio: (B, T) - # audio_len: (B,) - return audio, audio_len + with torch.no_grad(), torch.autocast(device_type=codes.device.type, dtype=torch.float32): + # Make a copy to avoid modifying the original tensor if it's used elsewhere + codes_copy = codes.clone() + # Replace eos and bos tokens with padding in the copied tensor + codes_copy[codes == self.audio_bos_id] = 0 # zero is the padding token + codes_copy[codes == self.audio_eos_id] = 0 + # Pass the modified integer token IDs + audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len) 
+ # audio: (B, T) + # audio_len: (B,) + return audio, audio_len def embed_audio_tokens(self, audio_tokens): # audio_tokens: (B, C, T') @@ -1013,6 +1013,8 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): batch_output = self.process_batch(batch, mode="val") + # self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction + # head. If we use local_transformer, then the local_transformer returns "local_transformer_logits" loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] alignment_loss = batch_output['alignment_loss'] @@ -1033,7 +1035,7 @@ def validation_step(self, batch, batch_idx): if batch_idx == 0 and self.global_rank == 0: self.log_train_val_audio_example( logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens - ) + ) # Currently, only logs parallel prediction (logits). No local transformer results if ( self.model_type != 'decoder_pretrain_synthesizer' and len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1 diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py old mode 100644 new mode 100755 index baf9a1648282..b6b076f8bd2d --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -54,6 +54,16 @@ HAVE_FSSPEC = False +from contextlib import contextmanager +@contextmanager +def default_precision(dtype=torch.float32): + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(default_dtype) + def get_padding(kernel_size: int, dilation: int = 1) -> int: return (kernel_size * dilation - dilation) // 2 @@ -407,40 +417,41 @@ def forward(self, x, l2_norm=False): Shapes: - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` """ - x.squeeze_(1) - # if you torch spec compute it otherwise use the mel spec computed by the AP - if self.use_torch_spec: - x = self.torch_spec(x) + with default_precision(torch.float32): + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) - if self.log_input: - x = (x + 1e-6).log() - x = self.instancenorm(x).unsqueeze(1) + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) - x = self.conv1(x) - x = self.relu(x) - x = self.bn1(x) + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) - x = x.reshape(x.size()[0], -1, x.size()[-1]) + x = x.reshape(x.size()[0], -1, x.size()[-1]) - w = self.attention(x) + w = self.attention(x) - if self.encoder_type == "SAP": - x = torch.sum(x * w, dim=2) - elif self.encoder_type == "ASP": - mu = torch.sum(x * w, dim=2) - sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) - x = torch.cat((mu, sg), 1) + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) - x = x.view(x.size()[0], -1) - x = self.fc(x) + x = x.view(x.size()[0], -1) + x = self.fc(x) - if l2_norm: - x = torch.nn.functional.normalize(x, p=2, dim=1) + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) return 
x def get_torch_mel_spectrogram_class(self, audio_config): @@ -602,12 +613,12 @@ def _get_extra_padding_for_conv1d( hidden_states: torch.Tensor, ) -> torch.Tensor: """See `pad_for_conv1d`.""" - length = hidden_states.shape[-1] - n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 - n_frames = torch.ceil(n_frames).to(torch.int64) - 1 - ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total - - return ideal_length - length + with default_precision(torch.float32): + length = hidden_states.shape[-1] + n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = torch.ceil(n_frames).to(torch.int64) - 1 + ideal_length = (n_frames * self.stride).long() + self.kernel_size - self.padding_total + return (ideal_length - length).long() @staticmethod # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._pad1d @@ -999,7 +1010,8 @@ def output_types(self): def forward(self, audio): scores_list = [] fmap_list = [] - spec = self.compute_stft(audio) + # run spec compute on fp32 and convert out to the model training type + spec = self.compute_stft(audio.float()).to(audio.dtype) for band, disc in zip(self.stft_bands, self.discriminators): spec_band = spec[:, :, :, band[0] : band[1]] score, fmap = disc(spec=spec_band) @@ -1765,7 +1777,8 @@ def forward(self, audio, audio_len): out = res_layer(inputs=out, input_len=encoded_len) out = act(out) - encoded_len = encoded_len // down_sample_rate + with default_precision(torch.float32): + encoded_len = (encoded_len // down_sample_rate).long() # [B, 2 * C, T / down_sample_rate] out = down_sample_conv(inputs=out, input_len=encoded_len) @@ -1886,7 +1899,8 @@ def forward(self, audio, audio_len): out = res_layer(inputs=out, input_len=encoded_len) out = act(out) - encoded_len = encoded_len // down_sample_rate + with default_precision(torch.float32): + encoded_len = (encoded_len // down_sample_rate).long() # [B, 2 * C, T / down_sample_rate] out = down_sample_conv(inputs=out, input_len=encoded_len) @@ -2012,7 +2026,8 @@ def forward(self, inputs, input_len): for act, res_layer, up_sample_conv, up_sample_rate in zip( self.activations, self.res_layers, self.up_sample_conv_layers, self.up_sample_rates ): - audio_len = audio_len * up_sample_rate + with default_precision(torch.float32): + audio_len = (audio_len * up_sample_rate).long() out = act(out) # [B, C / 2, T * up_sample_rate] out = up_sample_conv(inputs=out, input_len=audio_len) diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index b72cef6d3f92..2fe354df7765 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -227,7 +227,7 @@ def attn_naive( # attn_prior or square mask or vanilla attention if attn_prior is not None: - eps = 1e-8 + eps = torch.finfo(attn_prior.dtype).tiny attn_prior = attn_prior[:, :T] # trim for inference attn_prior = attn_prior[:, None] attn_prior_log = torch.log(attn_prior + eps) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 634bafda5e5a..82215085963e 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -31,9 +31,28 @@ def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0 metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) return metrics +def update_config(model_cfg, codecmodel_path): + ''' helper function to rename older yamls from t5 
to magpie ''' + model_cfg.codecmodel_path = codecmodel_path + if hasattr(model_cfg, 'text_tokenizer'): + # Backward compatibility for models trained with absolute paths in text_tokenizer + model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722" + model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0 + model_cfg.train_ds = None + model_cfg.validation_ds = None + if "t5_encoder" in model_cfg: + model_cfg.encoder = model_cfg.t5_encoder + del model_cfg.t5_encoder + if "t5_decoder" in model_cfg: + model_cfg.decoder = model_cfg.t5_decoder + del model_cfg.t5_decoder + return model_cfg + def run_inference( hparams_file, checkpoint_file, + nemo_file, datasets, out_dir, temperature, @@ -54,33 +73,37 @@ def run_inference( confidence_level=0.95, use_local_transformer=False ): - # import ipdb; ipdb.set_trace() - model_cfg = OmegaConf.load(hparams_file).cfg + # Load model + if hparams_file is not None: + model_cfg = OmegaConf.load(hparams_file) + if "cfg" in model_cfg: + model_cfg = model_cfg.cfg - with open_dict(model_cfg): - model_cfg.codecmodel_path = codecmodel_path - if hasattr(model_cfg, 'text_tokenizer'): - # Backward compatibility for models trained with absolute paths in text_tokenizer - model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722" - model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0 - model_cfg.train_ds = None - model_cfg.validation_ds = None + with open_dict(model_cfg): + model_cfg = update_config(model_cfg, codecmodel_path) + model = MagpieTTS_Model(cfg=model_cfg) + model.use_kv_cache_for_inference = True - model = MagpieTTSModel(cfg=model_cfg) - model.use_kv_cache_for_inference = True + # Load weights from checkpoint file + print("Loading weights from checkpoint") + ckpt = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(ckpt['state_dict']) + checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] + elif nemo_file is not None: + model_cfg = MagpieTTS_Model.restore_from(nemo_file, return_config=True) + with open_dict(model_cfg): + model_cfg = update_config(model_cfg, codecmodel_path) + model = MagpieTTS_Model.restore_from(nemo_file, override_config_path=model_cfg) + model.use_kv_cache_for_inference = True + checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] + else: + raise ValueError("Need a checkpoint") - # Load weights from checkpoint file - print("Loading weights from checkpoint") - ckpt = torch.load(checkpoint_file, weights_only=False) - model.load_state_dict(ckpt['state_dict']) print("Loaded weights.") model.cuda() model.eval() - # import ipdb; ipdb.set_trace() - checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}_sv_{}".format( checkpoint_name, temperature, @@ -226,7 +249,7 @@ def run_inference( with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: json.dump(mean_rtf_metrics, f, indent=4) - all_experiment_csv = os.path.join(out_dir, "all_experiment_metrics.csv") + all_experiment_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") if not os.path.exists(all_experiment_csv): with open(all_experiment_csv, "w") as f: 
f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative\n") @@ -248,11 +271,11 @@ def run_inference( f.write(f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']}\n") print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") - def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt") + parser.add_argument('--nemo_file', type=str, default=None) parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo") parser.add_argument('--datasets', type=str, default="libri_unseen_test_12.5") parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmountedresson/") @@ -286,7 +309,7 @@ def main(): if args.apply_prior_to_layers is not None: apply_prior_to_layers = [int(l.strip()) for l in args.apply_prior_to_layers.split(",")] - if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null"): + if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null") and (args.checkpoint_files != "null"): hparam_files = args.hparams_files.split(",") checkpoint_files = args.checkpoint_files.split(",") print("Running inference for hparams files: ", hparam_files) @@ -296,6 +319,7 @@ def main(): run_inference( hparams_file=hparams_file, checkpoint_file=checkpoint_file, + nemo_file=None, datasets=args.datasets.split(","), out_dir=args.out_dir, temperature=args.temperature, @@ -317,6 +341,33 @@ def main(): use_local_transformer=args.use_local_transformer ) return + elif (args.nemo_file is not None): + nemo_file = args.nemo_file + print("Running inference for nemo file: ", nemo_file) + run_inference( + hparams_file=None, + checkpoint_file=None, + nemo_file=nemo_file, + datasets=args.datasets.split(","), + out_dir=args.out_dir, + temperature=args.temperature, + topk=args.topk, + codecmodel_path=args.codecmodel_path, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + batch_size=args.batch_size, + sv_model=args.sv_model, + asr_model_name=args.asr_model_name, + num_repeats=args.num_repeats, + apply_attention_prior=args.apply_attention_prior, + attention_prior_epsilon=args.attention_prior_epsilon, + attention_prior_lookahead_window=args.attention_prior_lookahead_window, + 
estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, + confidence_level=args.confidence_level, + use_local_transformer=args.use_local_transformer + ) else: BASE_EXP_DIR = args.base_exp_dir DRACO_EXP_DIR = args.draco_exp_dir @@ -352,12 +403,12 @@ def main(): print(f"Running command: {scp_command_hparams}") os.system(scp_command_hparams) print("Copied hparams file.") - # import ipdb; ipdb.set_trace() print("Hparams file path: ", hparams_copy_path) print("Checkpoint file path: ", checkpoint_copy_path) run_inference( hparams_copy_path, checkpoint_copy_path, + nemo_file=None, datasets=args.datasets.split(","), out_dir=args.out_dir, temperature=args.temperature, From 718173ae9ac111fe90320d0ae945834ef0f4532c Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Tue, 15 Apr 2025 10:36:01 -0700 Subject: [PATCH 018/113] Bug fix in context text embedding initialization (#58) * bug fix in context text embedding initialization Signed-off-by: Paarth Neekhara * bug fixes in infer and evaluate Signed-off-by: Paarth Neekhara --------- Signed-off-by: Paarth Neekhara --- nemo/collections/tts/models/magpietts.py | 6 +++--- scripts/magpietts/infer_and_evaluate.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 2ce6cff75c02..90608349547a 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -124,9 +124,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.context_audio_bos_id = cfg.num_audio_tokens_per_codebook - 2 # For backward compatibility self.context_audio_eos_id = cfg.num_audio_tokens_per_codebook - 1 # For backward compatibility - if self.use_text_conditioning_encoder: - self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) - self.model_type = cfg.get('model_type', 'single_encoder_sv_tts') if self.model_type == 'decoder_context_tts': @@ -140,6 +137,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg=cfg, trainer=trainer) + if self.use_text_conditioning_encoder: + self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) + audio_embeddings = [] for _ in range(cfg.num_audio_codebooks): audio_embeddings.append(nn.Embedding(cfg.num_audio_tokens_per_codebook, cfg.embedding_dim)) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 82215085963e..d35019562350 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -82,7 +82,7 @@ def run_inference( with open_dict(model_cfg): model_cfg = update_config(model_cfg, codecmodel_path) - model = MagpieTTS_Model(cfg=model_cfg) + model = MagpieTTSModel(cfg=model_cfg) model.use_kv_cache_for_inference = True # Load weights from checkpoint file @@ -91,10 +91,10 @@ def run_inference( model.load_state_dict(ckpt['state_dict']) checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] elif nemo_file is not None: - model_cfg = MagpieTTS_Model.restore_from(nemo_file, return_config=True) + model_cfg = MagpieTTSModel.restore_from(nemo_file, return_config=True) with open_dict(model_cfg): model_cfg = update_config(model_cfg, codecmodel_path) - model = MagpieTTS_Model.restore_from(nemo_file, override_config_path=model_cfg) + model = 
MagpieTTSModel.restore_from(nemo_file, override_config_path=model_cfg) model.use_kv_cache_for_inference = True checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] else: @@ -164,7 +164,9 @@ def run_inference( context_duration_max=context_durration_max, ) assert len(test_dataset) == len(manifest_records), "Dataset length and manifest length should be the same. Dataset length: {}, Manifest length: {}".format(len(test_dataset), len(manifest_records)) - test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test') + + test_dataset.text_tokenizer = model.tokenizer + test_dataset.text_conditioning_tokenizer = model.text_conditioning_tokenizer test_data_loader = torch.utils.data.DataLoader( test_dataset, From 7e2cdca74a866ecefdbe01c0076ad9b5d140ac61 Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 16 Apr 2025 09:45:34 -0400 Subject: [PATCH 019/113] update defaul params in config (#59) Signed-off-by: Jason --- .../tts/conf/magpietts/magpietts_dc_en.yaml | 179 ++++++++++++++++++ examples/tts/conf/magpietts/magpietts_en.yaml | 34 ++-- ...se_en.yaml => magpietts_lhotse_dc_en.yaml} | 32 +--- 3 files changed, 202 insertions(+), 43 deletions(-) create mode 100644 examples/tts/conf/magpietts/magpietts_dc_en.yaml rename examples/tts/conf/magpietts/{magpietts_lhotse_en.yaml => magpietts_lhotse_dc_en.yaml} (85%) diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml new file mode 100644 index 000000000000..58f09cf8da45 --- /dev/null +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -0,0 +1,179 @@ +name: Magpie-TTS-EN + +max_epochs: ??? +# Adjust batch size based on GPU memory +batch_size: 16 +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. +weighted_sampling_steps_per_epoch: null + +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +train_ds_meta: ??? +val_ds_meta: ??? + +# Modify these values based on your sample rate +sample_rate: 22050 + +model: + model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer + use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. + context_duration_min: 3.0 + context_duration_max: 5.0 + num_audio_codebooks: 8 + num_audio_tokens_per_codebook: 2018 # 2016 from codec + Keep at least 2 extra for eos/bos ids + codec_model_downsample_factor: 1024 + load_cached_codes_if_available: true + prior_scaling_factor: 0.5 + prior_end_step: 12000 + prior_scaledown_start_step: 8000 # Prior will always be on before this step. + indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. + alignment_loss_scale: 0.002 + embedding_dim: 768 + codecmodel_path: ??? 
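+  # Required: path to the pretrained audio codec .nemo checkpoint used to encode/decode audio codes.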
+ max_epochs: ${max_epochs} + steps_per_epoch: ${weighted_sampling_steps_per_epoch} + sample_rate: ${sample_rate} + cfg_unconditional_prob: 0.1 + + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + use_local_transformer: false + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + + text_tokenizers: # Add more languages for multi-lingual TTS + english_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + + train_ds: + dataset: + _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset + dataset_meta: ${train_ds_meta} + weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} + sample_rate: ${sample_rate} + # speaker_path: ${speaker_path} + min_duration: 0.2 + max_duration: 20.0 + + dataloader_params: + batch_size: ${batch_size} + num_workers: 4 + drop_last: true + pin_memory: true + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset + dataset_meta: ${val_ds_meta} + sample_rate: ${sample_rate} + # speaker_path: ${speaker_path} + min_duration: 0.2 + max_duration: 20.0 + + dataloader_params: + batch_size: ${batch_size} + num_workers: 4 + pin_memory: true + + encoder: + n_layers: 6 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + decoder: + n_layers: 12 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 1 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: true + xa_d_memory: 768 + xa_n_heads: 1 + is_causal: true + apply_norm_to_cond: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + optim: + _target_: torch.optim.AdamW + lr: 1e-4 + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: bf16-mixed + max_epochs: ${max_epochs} + 
accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 1 + num_sanity_val_steps: 0 + benchmark: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + resume: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + resume_if_exists: true + resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index 12deab788047..2edbf0634e9e 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -17,27 +17,26 @@ sample_rate: 22050 model: model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. + use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise + context_duration_max: 5.0 num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2018 # 2016 from codec + Keep at least 2 extra for eos/bos ids codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: 0.5 prior_end_step: 12000 prior_scaledown_start_step: 8000 # Prior will always be on before this step. indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. - alignment_loss_scale: 0.0 + alignment_loss_scale: 0.002 embedding_dim: 768 codecmodel_path: ??? 
max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} + cfg_unconditional_prob: 0.1 # Alignment encoder parameters, to binarize the prior # This is used for attention-constrained training and inference @@ -85,7 +84,7 @@ model: weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} sample_rate: ${sample_rate} # speaker_path: ${speaker_path} - min_duration: 0.5 + min_duration: 0.2 max_duration: 20.0 dataloader_params: @@ -100,12 +99,12 @@ model: dataset_meta: ${val_ds_meta} sample_rate: ${sample_rate} # speaker_path: ${speaker_path} - min_duration: 0.5 + min_duration: 0.2 max_duration: 20.0 dataloader_params: batch_size: ${batch_size} - num_workers: 0 + num_workers: 4 pin_memory: true encoder: @@ -117,7 +116,7 @@ model: p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: false - is_causal: False + is_causal: true apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true @@ -141,12 +140,12 @@ model: d_model: 768 d_ffn: 3072 sa_n_heads: 12 - kernel_size: 3 + kernel_size: 1 p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: true xa_d_memory: 768 - xa_n_heads: 12 + xa_n_heads: 1 is_causal: true apply_norm_to_cond: true apply_norm_out: true @@ -154,9 +153,8 @@ model: use_learnable_pos_emb: true optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] + _target_: torch.optim.AdamW + lr: 1e-4 sched: name: ExponentialLR @@ -167,14 +165,14 @@ trainer: devices: -1 accelerator: gpu strategy: ddp_find_unused_parameters_true - precision: 32 + precision: bf16-mixed max_epochs: ${max_epochs} accumulate_grad_batches: 1 enable_checkpointing: False # Provided by exp_manager logger: false # Provided by exp_manager log_every_n_steps: 100 - val_check_interval: 500 - # check_val_every_n_epoch: 10 + check_val_every_n_epoch: 1 + num_sanity_val_steps: 0 benchmark: false exp_manager: diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml similarity index 85% rename from examples/tts/conf/magpietts/magpietts_lhotse_en.yaml rename to examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index f61cfc4c337a..42f0887dfc0b 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -5,15 +5,12 @@ sample_rate: 22_050 model: use_lhotse: true - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer + model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. 
- context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts context_duration_min: 3.0 context_duration_max: 5.0 - speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2_048 # Keep atleast 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2_018 # Keep atleast 2 extra for eos/bos ids codec_model_downsample_factor: 1_024 codec_model_name: "21fpsCausalDecoder" load_cached_codes_if_available: true @@ -125,21 +122,7 @@ model: p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2_048 - use_learnable_pos_emb: true - - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 - d_model: 768 - d_ffn: 3_072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false + is_causal: true apply_norm_out: true max_length_causal_mask: 2_048 use_learnable_pos_emb: true @@ -149,12 +132,12 @@ model: d_model: 768 d_ffn: 3_072 sa_n_heads: 12 - kernel_size: 3 + kernel_size: 1 p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: true xa_d_memory: 768 - xa_n_heads: 12 + xa_n_heads: 1 is_causal: true apply_norm_to_cond: true apply_norm_out: true @@ -162,9 +145,8 @@ model: use_learnable_pos_emb: true optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] + _target_: torch.optim.AdamW + lr: 1e-4 sched: name: ExponentialLR From aae4e6058c52d3a0b899f4aaf7f4148bb2536e6e Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Thu, 17 Apr 2025 10:47:17 -0700 Subject: [PATCH 020/113] magpie top k bug fix (#60) Signed-off-by: Paarth Neekhara --- nemo/collections/tts/models/magpietts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 90608349547a..6e6ec76739b4 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -427,7 +427,7 @@ def sample_codes_from_local_transformer(self, dec_output, temperature=0.7, topk= indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(-1) # (B, num_tokens_per_codebook) codebook_logits_rescored = codebook_logits.clone() codebook_logits_rescored[indices_to_remove] = float('-inf') - codebook_probs = torch.softmax(codebook_logits / temperature, dim=-1) # (B, num_tokens_per_codebook) + codebook_probs = torch.softmax(codebook_logits_rescored / temperature, dim=-1) # (B, num_tokens_per_codebook) codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) if use_cfg: codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size] @@ -462,7 +462,7 @@ def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, codebook_logits_rescored = codebook_logits.clone() codebook_logits_rescored[indices_to_remove] = float('-inf') - codebook_probs = torch.softmax(codebook_logits / temperature, dim=-1) # (B, num_tokens_per_codebook) + codebook_probs = torch.softmax(codebook_logits_rescored / temperature, dim=-1) # (B, num_tokens_per_codebook) codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) all_preds.append(codebook_preds) all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks) From 23e299a0bd14b666543b4bbcc7783f783acb0bd3 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 22 Apr 2025 13:26:02 -0700 Subject: [PATCH 021/113] Bugfix: num_audio_tokens_per_codebook (#62) Make sure to reserve enough tokens for special uses like EOS/BOS. 
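
For illustration, with the 2016-entry codec codebook assumed in these configs:
2016 codec tokens plus 8 reserved special ids (audio and context-audio BOS/EOS, plus spares) = 2024.
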
WARNING: old models will be incompatible with the updated inference YAMLs and will need to override the num_audio_tokens_per_codebook to the value they were trained with. --- examples/tts/conf/magpietts/magpietts_dc_en.yaml | 2 +- examples/tts/conf/magpietts/magpietts_en.yaml | 2 +- examples/tts/conf/magpietts/magpietts_inference_en.yaml | 2 +- .../tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml | 2 +- examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml | 2 +- examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index 58f09cf8da45..6a6c87ae12c9 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -21,7 +21,7 @@ model: context_duration_min: 3.0 context_duration_max: 5.0 num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2018 # 2016 from codec + Keep at least 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: 0.5 diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index 2edbf0634e9e..607c4e7a66bf 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -23,7 +23,7 @@ model: context_duration_min: 3.0 context_duration_max: 5.0 num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2018 # 2016 from codec + Keep at least 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: 0.5 diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index 38b383fcfc58..c54221b9a63f 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -32,7 +32,7 @@ model: speaker_emb_dim: 192 max_decoder_steps: 500 num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: null diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index 5cbaf589ed1b..cd82580b5ec5 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -32,7 +32,7 @@ model: speaker_emb_dim: 192 max_decoder_steps: 500 num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) 
codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: null diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index 42f0887dfc0b..714bc783e75e 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -10,7 +10,7 @@ model: context_duration_min: 3.0 context_duration_max: 5.0 num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2_018 # Keep atleast 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) codec_model_downsample_factor: 1_024 codec_model_name: "21fpsCausalDecoder" load_cached_codes_if_available: true diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index e262833ba907..e315b7e1fb88 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -24,7 +24,7 @@ model: context_duration_max: 8.0 speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids + num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: 0.5 From 3e1e7fc1baaf4b0d22d25afec406ca537e6e8540 Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Thu, 24 Apr 2025 06:45:51 -0700 Subject: [PATCH 022/113] Add num_codebooks and codebook_size to codec interface (#65) Signed-off-by: Ryan --- nemo/collections/tts/models/audio_codec.py | 14 ++++++++ .../tts/modules/audio_codec_modules.py | 36 ++++++++++++++----- .../tts/modules/encodec_modules.py | 21 ++++++++++- 3 files changed, 62 insertions(+), 9 deletions(-) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 134cf1e0f492..1a1597be138c 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -195,6 +195,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def dtype(self): return next(self.parameters()).dtype + @property + def num_codebooks(self): + if self.vector_quantizer is None: + raise ValueError("This AudioCodecModel does not have a vector quantizer.") + + return self.vector_quantizer.num_codebooks + + @property + def codebook_size(self): + if self.vector_quantizer is None: + raise ValueError("This AudioCodecModel does not have a vector quantizer.") + + return self.vector_quantizer.codebook_size + def state_dict(self, destination=None, prefix='', keep_vars=False): if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index b6b076f8bd2d..846d5a028332 100755 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -1117,6 +1117,16 @@ def forward(self, audio_real, audio_gen): class VectorQuantizerBase(NeuralModule, ABC): + @property + @abstractmethod + def num_codebooks(self) -> int: + pass + + @property + @abstractmethod + def codebook_size(self) -> int: + pass + @property def input_types(self): return { @@ -1196,6 +1206,11 @@ def __init__(self, num_levels: List[int], eps: float 
= 1e-3): logging.debug('\tcodebook_size: %s', self.codebook_size) logging.debug('\teps: %s', self.eps) + @property + def num_codebooks(self): + """Returns the number of codebooks.""" + return 1 + @property def codebook_size(self): """Returns the size of the corresponding codebook.""" @@ -1392,19 +1407,24 @@ def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs): logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group) @property - def codebook_dim(self): - """Input vector dimension.""" - return self.codebook_dim_per_group * self.num_groups + def num_codebooks(self): + """Returns the number of codebooks.""" + return self.num_groups @property - def codebook_size_per_group(self): - """Returns the size of the implicit codebook for each group.""" + def codebook_size(self): + """Returns the size of the codebook for each group.""" return self.fsqs[0].codebook_size + #@property + #def codebook_size(self): + # """Returns the size of the implicit codebook.""" + # return self.codebook_size_per_group**self.num_groups + @property - def codebook_size(self): - """Returns the size of the implicit codebook.""" - return self.codebook_size_per_group**self.num_groups + def codebook_dim(self): + """Input vector dimension.""" + return self.codebook_dim_per_group * self.num_groups @typecheck() def forward(self, inputs, input_len): diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index e9a1556ab700..a1c552e21074 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -739,6 +739,16 @@ def __init__( ] ) + @property + def num_codebooks(self): + """Returns the number of codebooks.""" + return len(self.codebooks) + + @property + def codebook_size(self): + """Returns the size of the codebook for each group.""" + return self.codebooks[0].codebook_size + # Override output types, since this quantizer returns commit_loss @property def output_types(self): @@ -837,7 +847,6 @@ class GroupResidualVectorQuantizer(VectorQuantizerBase): def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwargs): super().__init__() - self.num_codebooks = num_codebooks self.num_groups = num_groups self.codebook_dim = codebook_dim @@ -858,6 +867,16 @@ def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwa logging.debug('\tnum_codebooks_per_group: %d', self.num_codebooks_per_group) logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group) + @property + def num_codebooks(self): + """Returns the number of codebooks.""" + return self.num_groups + + @property + def codebook_size(self): + """Returns the size of the codebook for each group.""" + return self.rvqs[0].codebook_size + @property def num_codebooks_per_group(self): """Number of codebooks for each group.""" From 4b4914e445ca459d58367a956668e71d6b63bafe Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Mon, 28 Apr 2025 09:50:14 -0700 Subject: [PATCH 023/113] bug fix in _setup_test_dataloader when num workers=0 (#67) Signed-off-by: Shehzeen Hussain --- nemo/collections/tts/models/magpietts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 6e6ec76739b4..7e43d4c090f4 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -1573,7 +1573,7 @@ def _setup_test_dataloader(self, dataset_cfg) -> 
torch.utils.data.DataLoader: # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) dataset.text_tokenizer, dataset.text_conditioning_tokenizer = setup_tokenizers( all_tokenizers_config=self.cfg.text_tokenizers, - use_text_conditioning_tokenizer=self.use_kv_cache_for_inference, + use_text_conditioning_tokenizer=self.use_text_conditioning_encoder, mode='test' ) From 902bce61fc5dec1052888d5e4637a5d8c1c0e5f1 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Mon, 28 Apr 2025 09:50:24 -0700 Subject: [PATCH 024/113] =?UTF-8?q?preference=20optimization=20updates,=20?= =?UTF-8?q?trainer=20updates=20remove=20redundant=20dat=E2=80=A6=20(#51)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * preference optimization updates, trainer updates remove redundant datagen class Signed-off-by: Shehzeen Hussain * revert model pt change, add freeze_model function Signed-off-by: Shehzeen Hussain * remove redundant inference class Signed-off-by: Shehzeen Hussain * remove custom freeze model function and use lightning inbuilt freeze instead Signed-off-by: Shehzeen Hussain * added a readme for magpie preference optimization Signed-off-by: Shehzeen Hussain * change class name from MagpieTTSModelInference to MagpieTTSModelPrefDataGen Signed-off-by: Shehzeen Hussain * update class name from MagpieTTSModelPrefDataGen to MagpieTTSModelOfflinePODataGen Signed-off-by: Shehzeen Hussain --------- Signed-off-by: Shehzeen Hussain --- examples/tts/magpietts.py | 6 +- nemo/collections/tts/models/__init__.py | 7 +- nemo/collections/tts/models/magpietts.py | 190 ------------------ .../magpietts_preference_optimization.py | 37 ++-- scripts/magpietts/README_magpie_po.md | 178 ++++++++++++++++ 5 files changed, 203 insertions(+), 215 deletions(-) create mode 100644 scripts/magpietts/README_magpie_po.md diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index c96c55d3dee4..a40438f8cc28 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -18,7 +18,7 @@ from nemo.collections.tts.models import ( MagpieTTSModel, - MagpieTTSModelInference, + MagpieTTSModelOfflinePODataGen, MagpieTTSModelOfflinePO, MagpieTTSModelOnlinePO, ) @@ -60,9 +60,7 @@ def main(cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt model = MagpieTTSModelOnlinePO(cfg=model_cfg, trainer=trainer) elif cfg.get('mode', 'train') == 'test': - model = MagpieTTSModelInference(cfg=cfg.model, trainer=trainer) - # elif cfg.get('mode', 'train') == 'test': - # model = MagpieTTSModelPrefDataGen(cfg=cfg.model, trainer=trainer) + model = MagpieTTSModelOfflinePODataGen(cfg=cfg.model, trainer=trainer) else: raise NotImplementedError(f"Only train, dpo_train and test modes are supported. 
Got {cfg.mode}") diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index b9491c56d1a0..612275ce35cc 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -17,11 +17,11 @@ from nemo.collections.tts.models.fastpitch import FastPitchModel from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL from nemo.collections.tts.models.hifigan import HifiGanModel -from nemo.collections.tts.models.magpietts import MagpieTTSModel, MagpieTTSModelInference +from nemo.collections.tts.models.magpietts import MagpieTTSModel from nemo.collections.tts.models.magpietts_preference_optimization import ( MagpieTTSModelOfflinePO, MagpieTTSModelOnlinePO, - MagpieTTSModelPrefDataGen, + MagpieTTSModelOfflinePODataGen, ) from nemo.collections.tts.models.mixer_tts import MixerTTSModel from nemo.collections.tts.models.radtts import RadTTSModel @@ -45,8 +45,7 @@ "MixerTTSModel", "RadTTSModel", "MagpieTTSModel", - "MagpieTTSModelInference", - "MagpieTTSModelPrefDataGen", + "MagpieTTSModelOfflinePODataGen", "MagpieTTSModelOfflinePO", "MagpieTTSModelOnlinePO", "Tacotron2Model", diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 7e43d4c090f4..64151a988bbb 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -1596,193 +1596,3 @@ def setup_test_data(self, cfg): def list_available_models(cls) -> List[PretrainedModelInfo]: return [] - -class MagpieTTSModelInference(MagpieTTSModel): - """Small override of MagpieTTSModel for parallel multi-GPU inference and metrics calculation. - This class is used in 'test' mode and leverages trainer.test() for multi-GPU/multi-node inference. - Saves the predicted audio files and logs the CER/WER metrics as individual json files for each audio. - """ - - def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): - super().__init__(cfg, trainer) - if cfg.get('pref_set_language', "en") == "en": - self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained( - model_name="nvidia/parakeet-tdt-1.1b" - ) - self.eval_asr_model.freeze() - - self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - model_name='titanet_large' - ) - self.eval_speaker_verification_model.freeze() - - if cfg.get('load_whisper_model', False): - from transformers import WhisperForConditionalGeneration, WhisperProcessor - - self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") - self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") - self.whisper_model.eval() - - def transcribe_with_whisper(self, audio_filepath, language): - speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000) - forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language=language) if language else None - inputs = self.whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features - inputs = inputs.to(self.device) - with torch.no_grad(): - predicted_ids = self.whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) - transcription = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) - result = transcription[0] - return result - - def process_text(self, input_text): - """ - Normalizes text for CER/WER calculation. 
- Taken from hallucination_eval.py - """ - # Convert text to lowercase - lower_case_text = input_text.lower() - - # Remove commas from text - no_comma_text = lower_case_text.replace(",", "") - # Replace "-" with spaces - no_dash_text = no_comma_text.replace("-", " ") - no_dash_text = no_dash_text.replace("'", "") - no_dash_text = no_dash_text.replace(";", "") - no_dash_text = no_dash_text.replace(".", "") - - # Replace double spaces with single space - single_space_text = " ".join(no_dash_text.split()) - - single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) - - # @shehzeen: Added this to handle some common errors in ASR transcripts - single_space_text.replace("h t t p", "http") - single_space_text.replace("w w w", "www") - - return single_space_text - - def get_speaker_embeddings_from_filepaths(self, filepaths): - audio_batch = [] - audio_lengths = [] - for filepath in filepaths: - audio, sr = sf.read(filepath) - if sr != 16000: - audio = librosa.core.resample(audio, orig_sr=sr, target_sr=16000) - audio_tensor = torch.tensor(audio, dtype=torch.float32, device=self.device) - audio_batch.append(audio_tensor) - audio_lengths.append(audio_tensor.size(0)) - - batch_audio_lens = torch.tensor(audio_lengths, device=self.device).long() - max_audio_len = int(batch_audio_lens.max().item()) - audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len]) - - _, speaker_embeddings = self.eval_speaker_verification_model.forward( - input_signal=audio_batch, input_signal_length=batch_audio_lens - ) - - return speaker_embeddings - - def test_step(self, batch, batch_idx): - with torch.no_grad(): - test_dl_batch_size = self._test_dl.batch_size - temperature = self.cfg.get('inference_temperature', 0.7) - topk = self.cfg.get('inference_topk', 80) - use_cfg = self.cfg.get('inference_use_cfg', False) - cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( - batch, - max_decoder_steps=self.cfg.get('max_decoder_steps', 500), - temperature=temperature, - topk=topk, - use_cfg=use_cfg, - cfg_scale=cfg_scale, - ) - predicted_audio_paths = [] - audio_durations = [] - batch_invalid = False - for idx in range(predicted_audio.size(0)): - predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() - predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] - item_idx = batch_idx * test_dl_batch_size + idx - # Save the predicted audio - log_dir = self.logger.log_dir - audio_dir = os.path.join(log_dir, 'audios') - if not os.path.exists(audio_dir): - os.makedirs(audio_dir) - audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') - audio_durations.append(len(predicted_audio_np) / self.cfg.sample_rate) - sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) - - predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) - predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]] - torch.save( - predicted_codes_torch, - os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'), - ) - predicted_audio_paths.append(audio_path) - - if not batch_invalid: - with torch.no_grad(): - try: - if self.cfg.get("pref_set_language", "en") == "en": - pred_transcripts = self.eval_asr_model.transcribe( - predicted_audio_paths, batch_size=len(predicted_audio_paths) - )[0] - pred_transcripts = [self.process_text(transcript) for transcript in pred_transcripts] - else: - 
pred_transcripts = [ - self.transcribe_with_whisper(audio_path, self.cfg.pref_set_language) - for audio_path in predicted_audio_paths - ] - pred_transcripts = [self.process_text(transcript) for transcript in pred_transcripts] - except Exception as e: - assert ( - predicted_audio_lens[idx] < 1000 - ).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}" - logging.warning(f"Exception during ASR transcription: {e}") - logging.warning( - "Skipping processing of the batch; generating metrics indicating a WER of 100% and " - "Speaker Similarity of 0.0" - ) - batch_invalid = True - continue # don't break since we want to continue building audio durations list - pred_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(predicted_audio_paths) - gt_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(batch['audio_filepaths']) - - for idx in range(predicted_audio.size(0)): - if not batch_invalid: - item_idx = batch_idx * test_dl_batch_size + idx - pred_transcript = pred_transcripts[idx] - gt_transcript = self.process_text(batch['raw_texts'][idx]) - - cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True) - wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False) - - spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy() - spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy() - - spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( - np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) - ) - else: - # Create an entry indicating invalid metrics - cer_gt = 1.0 - wer_gt = 1.0 - spk_similarity = 0.0 - pred_transcript = "" # do not change this string; subsequent processing relies on it - gt_transcript = self.process_text(batch['raw_texts'][idx]) - - item_metrics = { - 'cer_gt': float(cer_gt), - 'wer_gt': float(wer_gt), - 'duration': audio_durations[idx], - 'spk_similarity': float(spk_similarity), - 'pred_transcript': pred_transcript, - 'gt_transcript': gt_transcript, - } - - with open( - os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w' - ) as f: - json.dump(item_metrics, f) - diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index 3cfa03284cad..a7158e6393ff 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -26,18 +26,20 @@ from nemo.collections.tts.models import MagpieTTSModel -class MagpieTTSModelPrefDataGen(MagpieTTSModel): - """Small override to save inference metrics, used for datagen in Offline PO""" +class MagpieTTSModelOfflinePODataGen(MagpieTTSModel): + """Small override of MagpieTTSModel for parallel multi-GPU inference and metrics calculation. + This class is used in 'test' mode and leverages trainer.test() for multi-GPU/multi-node inference. + Saves the predicted audio files and logs the CER/WER metrics as individual json files for each audio. 
+ """ + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) if cfg.get('pref_set_language', "en") == "en": self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") self.eval_asr_model.freeze() - self.eval_asr_model.eval() self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') self.eval_speaker_verification_model.freeze() - self.eval_speaker_verification_model.eval() if cfg.get('load_whisper_model', False): from transformers import WhisperForConditionalGeneration, WhisperProcessor @@ -86,7 +88,7 @@ def test_step(self, batch, batch_idx): try: if self.cfg.get("pref_set_language", "en") == "en": pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths)) - pred_transcripts = [ process_text_for_cer(transcript) for transcript in pred_transcripts ] + pred_transcripts = [ process_text_for_cer(transcript.text) for transcript in pred_transcripts ] else: pred_transcripts = [] for audio_path in predicted_audio_paths: @@ -140,9 +142,13 @@ def test_step(self, batch, batch_idx): json.dump(item_metrics, f) class MagpieTTSModelOfflinePO(MagpieTTSModel): + """ + MagpieTTS_Model_OfflinePO is a class that extends MagpieTTS_Model to support + offline preference optimization (DPO, IPO, RPO). + Set cfg.model.dpo_loss_type to 'dpo', 'ipo', or 'rpo' to use the corresponding loss. + """ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) - # Copy cfg ref_model_cfg = copy.deepcopy(cfg) with open_dict(ref_model_cfg): ref_model_cfg.train_ds = None @@ -150,8 +156,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) print("Loading reference model from checkpoint") self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) - self.freeze_model(self._reference_model) - self._reference_model.eval() + self._reference_model.freeze() self._reference_model._no_state_dict = True print("Reference model loaded and frozen") @@ -163,7 +168,6 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): del state_dict[key] return state_dict - def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): """Compute the log probabilities of the given labels under the given logits. @@ -182,7 +186,6 @@ def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): else: return (per_token_logps * loss_mask).sum(-1) - # https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py def preference_loss(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, @@ -195,7 +198,7 @@ def preference_loss(self, policy_chosen_logps, loss_type="dpo", reference_free=False): """Compute the DPO loss for a batch of policy and reference model log probabilities. - + Referenced From: https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py Args: policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) policy_rejected_logps: Log probabilities of the policy model for the rejected responses. 
Shape: (batch_size,) @@ -373,6 +376,10 @@ def collect(key): self.validation_step_outputs.clear() class MagpieTTSModelOnlinePO(MagpieTTSModel): + """ + MagpieTTS_Model_OnlinePO is a class that extends MagpieTTS_Model to support + online preference optimization (GRPO). + """ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) # Copy cfg @@ -386,15 +393,13 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) print("Loading reference model from checkpoint") self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) - self.freeze_model(self._reference_model) - self._reference_model.eval() + self._reference_model.freeze() self._reference_model._no_state_dict = True print("Reference model loaded and frozen") if cfg.get('reward_asr_model', "nemo") == "nemo": self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") self.eval_asr_model.freeze() - self.eval_asr_model.eval() elif cfg.get('reward_asr_model', "nemo") == "whisper": from transformers import WhisperForConditionalGeneration, WhisperProcessor self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") @@ -405,7 +410,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') self.eval_speaker_verification_model.freeze() - self.eval_speaker_verification_model.eval() if cfg.get('load_whisper_model', False): from transformers import WhisperForConditionalGeneration, WhisperProcessor @@ -500,7 +504,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): with torch.no_grad(): if self.cfg.get("reward_asr_model", "nemo") == "nemo": pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths)) - pred_transcripts = [ process_text_for_cer(transcript) for transcript in pred_transcripts ] + pred_transcripts = [ process_text_for_cer(transcript.text) for transcript in pred_transcripts ] elif self.cfg.get("reward_asr_model", "nemo") == "whisper": pred_transcripts = [] for item_idx, audio_path in enumerate(predicted_audio_paths): @@ -804,7 +808,6 @@ def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, return speaker_embeddings def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper_model, device): - print("Transcribing with whisper", language) speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000) forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language) if language else None inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features diff --git a/scripts/magpietts/README_magpie_po.md b/scripts/magpietts/README_magpie_po.md new file mode 100644 index 000000000000..d6d243e81899 --- /dev/null +++ b/scripts/magpietts/README_magpie_po.md @@ -0,0 +1,178 @@ +### Offline Preference Alignment (DPO/RPO) + +Code: `nemo/collections/tts/models/magpietts_preference_optimization.py` + +Preference Alignment (DPO/RPO) involves the following steps +1) Create a list of text-context pairs for which we will generate preference data. +2) For each text-context pair generate multiple audios from a base T5-TTS checkpoint and calculate metrics (CER/SSIM) for each generation. 
+3) Create chosen-rejected pairs from the generated audio. +4) Finetune the base T5-TTS checkpoint on the chosen-rejected pairs. + +#### 1. Create text-context pairs +We pair a list of challenging texts with context audios from from Riva and LibriTTS dataset. We add a similar number of regular texts from LibriTTS and Riva (paired with random context audios). We also include examples with text contexts. There are other options for generating text-context pairs. + +``` +python scripts/magpietts/dpo/create_text_contextpairs.py \ + --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ + --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ + --nsamples_perpair 6 ; +``` +Each pair is repeated `nsamples_perpair` times which specifies how many samples we want to generate for each pair. The output manifest serves as the input for the next step. + +We can also explore other options for these text-context pairs as well depending on the task. + +#### 2. Generate audios for each text-context pair + +Next, we can generate audios from a base T5-TTS checkpoint using the following command. We pass the `audio_dir` as "/" since our text context pairs contains absolute paths. Model config arguments should be modified accordingly to match the base checkpoint architecture. We can run the below command on cluster to generate audios across multiple nodes. This command saves the generated audios along with the metrics for each generation in the `exp_dir`. Each generated audio file is accompanied with a `.json` file that has the CER/SSIM metrics. + + +``` +python examples/tts/magpietts.py \ +--config-name=magpietts_inference_en \ +mode=test \ +batch_size=64 \ ++init_from_ptl_ckpt="/mountdir/checkpoints/continuouscheckpoints_ks1_ks3/decodercontext_small_282.ckpt" \ +exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282" \ ++test_ds_meta.textcontextpairs.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json" \ ++test_ds_meta.textcontextpairs.audio_dir="/" \ ++test_ds_meta.textcontextpairs.feature_dir="/" \ +model.model_type="decoder_context_tts" \ +model.encoder.kernel_size=3 \ +model.decoder.kernel_size=1 \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.alignment_loss_scale=0.002 \ +model.prior_scaling_factor=null \ +model.load_cached_codes_if_available=false \ +trainer.num_nodes=${SLURM_JOB_NUM_NODES} +``` +#### 3. Create chosen-rejected pairs from the generations + +Next, we go through the generated audio directory and create chosen-rejected pairs. 
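Conceptually, the pairing step amounts to something like the sketch below. This is a simplified, hypothetical illustration: the grouping and ranking heuristics (lowest CER as chosen, highest as rejected, speaker similarity as tie-breaker) are assumptions for readability, and the actual logic lives in `create_preference_pairs.py`, invoked by the command that follows.

```
# Hypothetical sketch of building chosen/rejected pairs from the per-sample metrics
# json files written during generation (keys cer_gt / spk_similarity come from those files).
import json
from pathlib import Path

def make_preference_pairs(audio_dir, group_size=6, cer_threshold=0.01):
    metric_files = sorted(Path(audio_dir).glob("*_metrics.json"))
    pairs = []
    for start in range(0, len(metric_files), group_size):
        group = []
        for path in metric_files[start:start + group_size]:
            with open(path) as f:
                metrics = json.load(f)
            metrics["audio_path"] = str(path).replace("_metrics.json", ".wav")
            group.append(metrics)
        # Assumed ranking: prefer low CER, then high speaker similarity.
        group.sort(key=lambda m: (m["cer_gt"], -m["spk_similarity"]))
        chosen, rejected = group[0], group[-1]
        # Drop pairs whose best generation is still above the CER threshold.
        if chosen["cer_gt"] > cer_threshold:
            continue
        pairs.append({"chosen": chosen["audio_path"], "rejected": rejected["audio_path"]})
    return pairs
```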
+ +``` +python scripts/magpietts/dpo/create_preference_pairs.py \ +--input_manifest /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json \ +--generated_audio_dir /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/audios \ +--group_size 6 \ +--cer_threshold 0.01 \ +--val_size 256 ; +``` + +`cer_threshold=0.01` means that filter out pairs in which the chosen CER > 0.01. + +This command should save train and val manifests for DPO finetuning in the base directory of the generated_audio_dir, that is, `/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/` + +#### 4. DPO Finetuning Command + +Finally, we perform DPO finetuning using the following command: + +``` +python examples/tts/magpietts.py \ +batch_size=4 \ ++init_from_ptl_ckpt="/mountdir/checkpoints/decoder_21_epoch_2.ckpt" \ ++mode="dpo_train" \ +max_epochs=10 \ +exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/TrainingsICML/decodercontext_small_282" \ +exp_manager.checkpoint_callback_params.always_save_nemo=false \ +model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ ++train_ds_meta.dpopreftrain.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_train_manifest.json" \ ++train_ds_meta.dpopreftrain.audio_dir="/" \ ++train_ds_meta.dpopreftrain.feature_dir="/" \ ++val_ds_meta.dpoprefval.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_val_manifest.json" \ ++val_ds_meta.dpoprefval.audio_dir="/" \ ++val_ds_meta.dpoprefval.feature_dir="/" \ ++model.dpo_beta=0.01 \ ++model.dpo_sft_loss_weight=0.0 \ +model.model_type="decoder_context_tts" \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.alignment_loss_scale=0.001 \ +model.prior_scaling_factor=null \ +trainer.val_check_interval=200 \ +trainer.log_every_n_steps=10 \ +model.optim.lr=2e-7 \ +~model.optim.sched \ +trainer.num_nodes=${SLURM_JOB_NUM_NODES} +``` + +Note the following overrides in the above command: + +``` ++mode="dpo_train" \ +model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +``` + +Again, our manifest contain absolute paths so we specify `audio_dir="/"` . + +### Online Preference Optimization (GRPO) + +For online preference optmization, process is much simpler. + +1) Create a list of text-context pairs for which we will generate preference data (just one pair for a text-context not repeated). +We'll use the same process as above, just set `nsamples_perpair 1` in the command. 
+``` +python scripts/magpietts/dpo/create_text_contextpairs.py \ + --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ + --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ + --nsamples_perpair 1 ; +``` + +2. Train using GRPO + +``` +python examples/tts/magpietts.py \ ++mode="onlinepo_train" \ ++init_from_ptl_ckpt="/Data/ICML2025_CKPTS/icml2025_base_checkpoints/decodercontext_small_sp_ks3CorrectWithPrior_onlyphoneme_epoch161.ckpt" \ +max_epochs=1000 \ +exp_manager.exp_dir="/Data/Experiments/NewT5TTSGRPO/Try3NoDropoutBeta0.01_CFG/" \ ++train_ds_meta.grpotrainnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_train_nomls.json" \ ++train_ds_meta.grpotrainnomls.audio_dir="/" \ ++train_ds_meta.grpotrainnomls.feature_dir="/" \ ++val_ds_meta.grpovalnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_val_unseenspeakers_tinysubset.json" \ ++val_ds_meta.grpovalnomls.audio_dir="/" \ ++val_ds_meta.grpovalnomls.feature_dir="/" \ ++model.num_generations_per_item=6 \ ++model.grpo_beta=0.01 \ ++model.reference_free=true \ +model.decoder.p_dropout=0.0 \ +model.encoder.p_dropout=0.0 \ +model.model_type="decoder_context_tts" \ +model.use_text_conditioning_encoder=true \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.codecmodel_path="/Data/Checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.alignment_loss_scale=0.0 \ +model.prior_scaling_factor=null \ +model.train_ds.dataloader_params.num_workers=0 \ +model.validation_ds.dataloader_params.num_workers=0 \ +exp_manager.checkpoint_callback_params.monitor="val_mean_reward" \ +exp_manager.checkpoint_callback_params.mode="max" \ ++trainer.use_distributed_sampler=False \ ++model.inference_cfg_prob=0.5 \ ++model.inference_cfg_scale=2.5 \ +batch_size=2 \ +model.optim.lr=1e-6 \ +trainer.devices=2 \ +trainer.log_every_n_steps=1 \ +trainer.val_check_interval=50 \ +~model.optim.sched \ +trainer.num_nodes=${SLURM_JOB_NUM_NODES} ; +``` + +Note that setting `+model.reference_free=true` makes the `grpo_beta` param effectively 0 since it does not use the KL regularization loss and saves memory. If using the `grpo_beta > 0` and `+model.reference_free=false`, make sure to set dropout params to 0, `model.decoder.p_dropout=0.0` and +`model.encoder.p_dropout=0.0` for training stabilization. Recommended learning rate is `model.optim.lr=1e-6` or lower. Setting `+model.inference_cfg_prob=0.5` means that for half of the generations will be generated using cfg, so that we optimize for our preferences in both cfg and non cfg inference modes. You may set `+model.inference_cfg_prob=0.0` if we only care about non-cfg inference. \ No newline at end of file From 454cca4eec66b53694e37f52b57c39153a59e271 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:04:18 -0700 Subject: [PATCH 025/113] [magpie][wandb] add loggings of pad ratios for text tokens and audio codes (#66) * [magpie][wandb] add loggings for pad ratios for text tokens and audio codes. 
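(As a quick illustration of the padding-ratio metrics added in this patch: each logged value reduces to a single expression per tensor, as in the minimal sketch below. Names mirror the training_step changes shown in the diff further below.)

```
# Sketch of the pad-ratio computation that gets logged to wandb.
import torch

def pad_ratio_percent(lengths, max_len):
    """Percentage of positions in a (batch_size, max_len) padded batch that are padding."""
    batch_size = lengths.shape[0]
    total_valid = lengths.sum()
    return float(100 * (1 - total_valid / (batch_size * max_len)))

# Example: text tokens padded to length 12 with true lengths 12, 9 and 6 -> 25.0% padding.
print(pad_ratio_percent(torch.tensor([12, 9, 6]), max_len=12))
```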
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpie][wandb] fix pad ratio calculation Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- nemo/collections/tts/models/magpietts.py | 29 ++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 64151a988bbb..b0ff1e0042e4 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -1009,6 +1009,35 @@ def training_step(self, batch, batch_idx): if local_transformer_loss is not None: self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) + # Log batch info + batch_size, text_token_max_len = batch["text"].shape + text_token_total_num = batch["text_lens"].sum() + batch_info_dict = { + "train/batch_size": batch_size, + "train/text_token_max_len": text_token_max_len, + "train/text_token_total_num_in_batch": text_token_total_num, + "train/text_token_pad_ratio_percent_in_batch": 100 * (1 - text_token_total_num / (batch_size * text_token_max_len)), + } + + if "audio_codes" in batch: + audio_codes_max_len = batch["audio_codes"].shape[-1] + audio_codes_total_num = batch["audio_codes_lens"].sum() + batch_info_dict.update({ + "train/audio_codes_max_len": audio_codes_max_len, + "train/audio_codes_total_num_in_batch": audio_codes_total_num, + "train/audio_codes_pad_ratio_percent_in_batch": 100 * (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)), + }) + else: + audio_samples_max_len = batch["audio"].shape[-1] + audio_samples_total_num = batch["audio_lens"].sum() + batch_info_dict.update({ + "train/audio_samples_max_len": audio_samples_max_len, + "train/audio_samples_total_num_in_batch": audio_samples_total_num, + "train/audio_samples_pad_ratio_percent_in_batch": 100 * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)), + }) + + self.log_dict(batch_info_dict, on_step=True) + return loss def validation_step(self, batch, batch_idx): From e4e5a757caa2ff10dc3251304023f29d46a53915 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:10:58 -0700 Subject: [PATCH 026/113] [magpie][wandb] add loggings of pad ratios for text tokens and audio codes (#66) * [magpie][wandb] add loggings for pad ratios for text tokens and audio codes. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpie][wandb] fix pad ratio calculation Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> From 8ba55061a0ebb161abff4b329e402d5307f4af98 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Mon, 28 Apr 2025 15:07:04 -0700 Subject: [PATCH 027/113] Magpietts 2503 refactor codebook config (#64) * Bugfix: num_audio_tokens_per_codebook Make sure to reserve enough tokens for special uses like EOS/BOS. WARNING: old models will be incompatible with the updated inference YAMLs and will need to override the num_audio_tokens_per_codebook to the value they were trained with. 
* Rework how number of codes and codebooks are handled (WIP) * Reorder the code a bit for clarity * Refactor codebook configuration * read codec parameters from codec checkpoint; remove corresponding configuration from Magpie YAML files * add mechanism for backward compatibility with older checkpoints: ** If using `infer_and_evaluate.py`, just set the --legacy_codebooks command line flag ** If running training or inference with the Hydra command line, override using the following flags: ``` forced_num_all_tokens_per_codebook: 2048 forced_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -1} # 2047 forced_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} # 2046 forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -4} # 2044 forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -3} # 2045 ``` * Add README on the codebook reorganization ... and how to load legacy checkpoints. * Cleanup * Cleanup and fixing typos * Cleanup * Cleanup * Clarify the README on the embedding table layout * README cleanup * Rename an attritube for clarity codec_model_downsample_factor --> codec_model_samples_per_frame --- .../README_magpietts_legacy_checkpoints.md | 82 ++++++++++++++ .../tts/conf/magpietts/magpietts_dc_en.yaml | 3 - examples/tts/conf/magpietts/magpietts_en.yaml | 3 - .../magpietts/magpietts_inference_en.yaml | 3 - .../magpietts_inference_multilingual_v1.yaml | 3 - .../magpietts/magpietts_lhotse_dc_en.yaml | 3 - .../magpietts/magpietts_multilingual_v1.yaml | 3 - .../tts/data/text_to_speech_dataset.py | 20 ++-- nemo/collections/tts/models/magpietts.py | 105 +++++++++++------- scripts/magpietts/infer_and_evaluate.py | 33 ++++-- 10 files changed, 179 insertions(+), 79 deletions(-) create mode 100644 examples/tts/README_magpietts_legacy_checkpoints.md diff --git a/examples/tts/README_magpietts_legacy_checkpoints.md b/examples/tts/README_magpietts_legacy_checkpoints.md new file mode 100644 index 000000000000..1d08438fd0aa --- /dev/null +++ b/examples/tts/README_magpietts_legacy_checkpoints.md @@ -0,0 +1,82 @@ +# Background +Magpie-TTS uses special tokens like AUDIO_BOS and AUDIO_EOS for its operation. The indices of these tokens are after the audio codec tokens, at the end of the embedding table. + +In April 2025 we changed the layout of the embedding table in a non-backwards compatible way: + +## Old Layout +With the most common codec configuration (2016 codes), the layout used to look like this: +``` +| Index | Token Description | Comments | +|---------|----------------------|-----------------------------------------------------------------------------------------------------------| +| [0] | Codec Token 0 | | +| [1] | Codec Token 1 | | +| [2] | Codec Token 2 | | +| ... | ... | | +| [2015] | Codec Token 2015 | | +| [2016] | | | +| [2017] | | | +| [2018] | | | +| ... 
| | | +| [2044] | Context Audio BOS | if model_type == `decoder_context_tts` | +| [2045] | Context Audio EOS | if model_type == `decoder_context_tts` | +| [2046] | Audio BOS | also used for Context Audio BOS if model_type == `multi_encoder_context_tts` or `single_encoder_sv_tts` | +| [2047] | Audio EOS | also used for Context Audio EOS if model_type == `multi_encoder_context_tts` or `single_encoder_sv_tts` | +``` + +## New Layout``` +The new layout for the same codec configuration is: +``` +| Index | Token Description | Comments | +---------------------------------------------| +| [0] | Codec Token 0 | | +| [1] | Codec Token 1 | | +| [2] | Codec Token 2 | | +| ... | ... | | +| [2015] | Codec Token 2015 | | +| [2016] | Audio BOS | | +| [2017] | Audio EOS | | +| [2018] | Context Audio BOS | | +| [2019] | Context Audio EOS | | +| [2020] | MASK token (MaskGit) | | +| [2021] | RESERVED_1 | | +| [2022] | RESERVED_2 | | +| [2023] | RESERVED_3 | | +``` + +# How to Train and Load a New Checkpoint +For new trainings and inference all configuration is automatic: +* The number of codebooks, codec codebooks size, and codec downsampling rate are all read from the codec checkpoint rather than configured in Magpie. +* The embedding table size is automatically set to codec_codebook_size + number_of_special_tokens (currently 2016+8=2024). There is no risk of accidentally stepping on codec tokens since the table sizes gets automatically sized with enough room for the special tokens. + +# How to Load Old Checkpoints +For checkpoints created before the change you can force legacy codebook layout in one of these ways: + +## If using `infer_and_evaluate.py` +Just set the `--legacy_codebooks` command line option. No need to update your YAML file – The script will automatically add the overrides. 
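To make the two layouts concrete, here is a hypothetical helper (not part of the repository) that resolves the special-token ids for either layout. The values follow the tables above and the `forced_*` overrides listed in the next section.

```
# Hypothetical helper illustrating the legacy vs. new special-token layouts described above.
def resolve_special_token_ids(codebook_size=2016, legacy=False, model_type="decoder_context_tts"):
    if not legacy:
        # New layout: special tokens sit directly after the codec tokens.
        return {
            "num_all_tokens_per_codebook": codebook_size + 8,  # 2016 codec tokens + 8 reserved = 2024
            "audio_bos_id": codebook_size + 0,                 # 2016
            "audio_eos_id": codebook_size + 1,                 # 2017
            "context_audio_bos_id": codebook_size + 2,         # 2018
            "context_audio_eos_id": codebook_size + 3,         # 2019
        }
    # Legacy layout: special tokens were pinned to the end of a fixed 2048-entry table.
    num_all = 2048
    ids = {
        "num_all_tokens_per_codebook": num_all,
        "audio_bos_id": num_all - 2,   # 2046
        "audio_eos_id": num_all - 1,   # 2047
    }
    if model_type == "decoder_context_tts":
        ids["context_audio_bos_id"] = num_all - 4  # 2044
        ids["context_audio_eos_id"] = num_all - 3  # 2045
    else:
        # multi_encoder_context_tts and single_encoder_sv_tts reuse the audio BOS/EOS ids.
        ids["context_audio_bos_id"] = num_all - 2  # 2046
        ids["context_audio_eos_id"] = num_all - 1  # 2047
    return ids
```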
+ +## If using a Hydra command line +You have two options: +### Add these to your command line +``` +# decoder context model ++model.forced_num_all_tokens_per_codebook=2048 +model.forced_audio_eos_id=2047 +model.forced_audio_bos_id=2046 +model.forced_context_audio_eos_id=2045 +model.forced_context_audio_bos_id=2044 + +# multi encoder context and any other model type ++model.forced_num_all_tokens_per_codebook=2048 +model.forced_audio_eos_id=2047 +model.forced_audio_bos_id=2046 +model.forced_context_audio_eos_id=2047 +model.forced_context_audio_bos_id=2046 +``` +# Or, add these overrides to your YAML file +``` +forced_num_all_tokens_per_codebook: 2048 +forced_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -1} # 2047 +forced_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} # 2046 + +# Depending on the old model type, the context_audio_bos_id and context_audio_eos_id will be different (choose one of the pairs below) + +# For `multi_encoder_context_tts`, `single_encoder_sv_tts`: +#forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -1} # 2047 +#forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} # 2046 + +# For `decoder_context_tts` models: +#forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -3} # 2045 +#forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -4} # 2044 +``` diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index 6a6c87ae12c9..517eeff8dcc0 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -20,9 +20,6 @@ model: use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. context_duration_min: 3.0 context_duration_max: 5.0 - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) - codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: 0.5 prior_end_step: 12000 diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index 607c4e7a66bf..7dbdc7f5f822 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -22,9 +22,6 @@ model: context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts context_duration_min: 3.0 context_duration_max: 5.0 - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) - codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: 0.5 prior_end_step: 12000 diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index c54221b9a63f..8c984a22748f 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -31,9 +31,6 @@ model: context_duration_max: 8.0 speaker_emb_dim: 192 max_decoder_steps: 500 - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) 
- codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: null prior_end_step: 0 diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index cd82580b5ec5..0e87fd677f42 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -31,9 +31,6 @@ model: context_duration_max: 8.0 speaker_emb_dim: 192 max_decoder_steps: 500 - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) - codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: null prior_end_step: 0 diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index 714bc783e75e..43f572e2d2e8 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -9,9 +9,6 @@ model: use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. context_duration_min: 3.0 context_duration_max: 5.0 - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) - codec_model_downsample_factor: 1_024 codec_model_name: "21fpsCausalDecoder" load_cached_codes_if_available: true prior_scaling_factor: 0.5 diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index e315b7e1fb88..a735ce958e79 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -23,9 +23,6 @@ model: context_duration_min: 3.0 context_duration_max: 8.0 speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2024 # 2016 from codec + 8 reserved tokens (for things like EOS/BOS etc.) - codec_model_downsample_factor: 1024 load_cached_codes_if_available: true prior_scaling_factor: 0.5 prior_end_step: 12000 diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 13054b4b9ae2..00742047537f 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -338,7 +338,7 @@ class MagpieTTSDataset(TextToSpeechDataset): max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration' will be ignored. volume_norm: Whether to apply volume normalization to loaded audio. - codec_model_downsample_factor: Downsample factor of the codec model (Num samples in waveform per codec frame). + codec_model_samples_per_frame: Num samples in waveform per codec frame (codec downsample factor). bos_id: Text BOS token id. eos_id: Text EOS token id. audio_bos_id: Audio BOS token id. 
@@ -365,7 +365,7 @@ def __init__( min_duration: Optional[float] = None, max_duration: Optional[float] = None, volume_norm: bool = True, - codec_model_downsample_factor: int = None, + codec_model_samples_per_frame: int = None, bos_id: int = None, eos_id: int = None, audio_bos_id: int = None, @@ -403,7 +403,7 @@ def __init__( self.context_audio_bos_id = context_audio_bos_id self.context_audio_eos_id = context_audio_eos_id self.num_audio_codebooks = num_audio_codebooks - self.codec_model_downsample_factor = codec_model_downsample_factor + self.codec_model_samples_per_frame = codec_model_samples_per_frame self.include_align_prior = prior_scaling_factor is not None self.prior_scaling_factor = prior_scaling_factor self.load_cached_codes_if_available = load_cached_codes_if_available @@ -420,8 +420,8 @@ def __init__( self.context_duration_max = context_duration_max def get_num_audio_samples_to_slice(self, duration, sample_rate): - num_codec_frames = int(duration * sample_rate / self.codec_model_downsample_factor) - num_audio_samples = num_codec_frames * self.codec_model_downsample_factor + num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) + num_audio_samples = num_codec_frames * self.codec_model_samples_per_frame return num_audio_samples def __getitem__(self, index): @@ -467,14 +467,14 @@ def __getitem__(self, index): # Pad audio to be multiple of downsample factor audio = torch.nn.functional.pad( audio, - (0, self.codec_model_downsample_factor - (audio.shape[0] % self.codec_model_downsample_factor)), + (0, self.codec_model_samples_per_frame - (audio.shape[0] % self.codec_model_samples_per_frame)), value=0, ) audio_len = audio.shape[0] example['audio_filepath'] = data.manifest_entry['audio_filepath'] example['audio'] = audio example['audio_len'] = audio_len - spec_len = int(audio_len / self.codec_model_downsample_factor) + 1 # +1 for EOS + spec_len = int(audio_len / self.codec_model_samples_per_frame) + 1 # +1 for EOS if self.load_cached_codes_if_available and 'context_audio_codes_path' in data.manifest_entry: context_audio_codes_path = data.manifest_entry['context_audio_codes_path'] @@ -482,7 +482,7 @@ def __getitem__(self, index): # Sample random duration between self.context_duration_min and self.context_duration_max _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) _num_frames_to_slice = int( - _context_duration_to_slice * self.sample_rate / self.codec_model_downsample_factor + _context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame ) if _num_frames_to_slice < context_audio_codes.shape[1]: start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) @@ -545,7 +545,7 @@ def __getitem__(self, index): example['context_audio_codes_len'] = context_audio_codes_len else: # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes - context_audio = torch.zeros(self.codec_model_downsample_factor, dtype=torch.float32) + context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32) context_audio_len = context_audio.shape[0] example['context_audio'] = context_audio example['context_audio_len'] = context_audio_len @@ -586,7 +586,7 @@ def __getitem__(self, index): example['has_text_context'] = False if self.pad_context_text_to_max_duration: _required_len = ( - int(self.context_duration_max * self.sample_rate / self.codec_model_downsample_factor) + 2 + int(self.context_duration_max * self.sample_rate / 
self.codec_model_samples_per_frame) + 2 ) # +2 for BOS and EOS if len(context_tokens) < _required_len: _pad_id = self.text_conditioning_tokenizer.pad_token_id diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index b0ff1e0042e4..375f39592a50 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -17,6 +17,7 @@ import string import time from typing import List +from enum import Enum, verify, CONTINUOUS, UNIQUE import librosa import numpy as np @@ -97,6 +98,37 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_devices + # load codec + codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) + # del codec discriminator to free memory + del codec_model.discriminator + + # Reserve special tokens (appended at the end of the codebook after the actual audio codec tokens) + # (the actual index is this value plus the number of codec tokens - do not use the Enum directy) + @verify(CONTINUOUS, UNIQUE) + class SpecialAudioToken(Enum): + AUDIO_BOS = 0 + AUDIO_EOS = 1 + AUDIO_CONTEXT_BOS = 2 + AUDIO_CONTEXT_EOS = 3 + MASK_TOKEN = 4 + # Reserved so that if we need to add more special tokens in the future the codebook size will remain the same + RESERVED_1 = 5 + RESERVED_2 = 6 + RESERVED_3 = 7 + + # Set up codebook configuration + self.num_audio_codebooks = codec_model.num_codebooks + self.codec_model_samples_per_frame = codec_model.samples_per_frame + # Our codebooks start with actual audio codec tokens, followed by special tokens. + # The `forced_*` options are for backward compatibility for models trained with older code. + num_audio_tokens = codec_model.codebook_size + self.audio_bos_id = cfg.get('forced_audio_bos_id',num_audio_tokens + SpecialAudioToken.AUDIO_BOS.value) + self.audio_eos_id = cfg.get('forced_audio_eos_id', num_audio_tokens + SpecialAudioToken.AUDIO_EOS.value) + self.context_audio_bos_id = cfg.get('forced_context_audio_bos_id', num_audio_tokens + SpecialAudioToken.AUDIO_CONTEXT_BOS.value) + self.context_audio_eos_id = cfg.get('forced_context_audio_eos_id', num_audio_tokens + SpecialAudioToken.AUDIO_CONTEXT_EOS.value) + self.num_all_tokens_per_codebook = cfg.get('forced_num_all_tokens_per_codebook',num_audio_tokens + len(SpecialAudioToken)) + # Setup tokenizer if hasattr(cfg, 'text_tokenizer'): # For backward compatibility for English-only models @@ -119,19 +151,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.bos_id = num_tokens - 2 self.eos_id = num_tokens - 1 - self.audio_bos_id = cfg.num_audio_tokens_per_codebook - 2 - self.audio_eos_id = cfg.num_audio_tokens_per_codebook - 1 - self.context_audio_bos_id = cfg.num_audio_tokens_per_codebook - 2 # For backward compatibility - self.context_audio_eos_id = cfg.num_audio_tokens_per_codebook - 1 # For backward compatibility - self.model_type = cfg.get('model_type', 'single_encoder_sv_tts') - if self.model_type == 'decoder_context_tts': - self.context_audio_bos_id = ( - cfg.num_audio_tokens_per_codebook - 4 - ) # Changing these to make them different from target audio bos and eos - self.context_audio_eos_id = cfg.num_audio_tokens_per_codebook - 3 - self.pad_context_text_to_max_duration = self.model_type == 'decoder_context_tts' self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) @@ -139,10 +160,14 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if 
self.use_text_conditioning_encoder: self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) + + # This needs to happen after super().__init__() + self._codec_model = codec_model + self._codec_model.freeze() #Lightning does requires_grad = False and self.eval() audio_embeddings = [] - for _ in range(cfg.num_audio_codebooks): - audio_embeddings.append(nn.Embedding(cfg.num_audio_tokens_per_codebook, cfg.embedding_dim)) + for _ in range(self.num_audio_codebooks): + audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim)) self.audio_embeddings = nn.ModuleList(audio_embeddings) if self.model_type != 'decoder_pretrain_synthesizer': @@ -151,7 +176,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) - self.final_proj = nn.Linear(cfg.decoder.d_model, cfg.num_audio_codebooks * cfg.num_audio_tokens_per_codebook) + self.final_proj = nn.Linear(cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook) if cfg.get('use_local_transformer', False): local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256) if local_transformer_hidden_dim != cfg.decoder.d_model: @@ -165,13 +190,13 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): sa_n_heads=self.cfg.get('local_transformer_n_heads', 1), kernel_size=1, is_causal=True, - max_length_causal_mask=cfg.num_audio_codebooks+2, + max_length_causal_mask=self.num_audio_codebooks+2, use_learnable_pos_emb=True, ) local_transformer_out_projections = [] - for _ in range(cfg.num_audio_codebooks): + for _ in range(self.num_audio_codebooks): # Have a separate projection layer for each codebook, to distinguish between them - local_transformer_out_projections.append(nn.Linear(local_transformer_hidden_dim, cfg.num_audio_tokens_per_codebook)) + local_transformer_out_projections.append(nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook)) self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections) if cfg.get('use_alignment_encoder', False): @@ -182,12 +207,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): temperature=15.0, ) - codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) - # del codec discriminator to free memory - del codec_model.discriminator - self._codec_model = codec_model - self._codec_model.freeze() #Lightning does requires_grad = False and self.eval() - if self.model_type == 'single_encoder_sv_tts': self._speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( model_name='titanet_large' @@ -334,13 +353,13 @@ def compute_local_transformer_logits(self, dec_out, audio_codes_target): for codebook_num in range(audio_codes_target.size(1)): # Using a separate projection layer for each codebook (to distinguish between them) # Checked the time - this loop is not taking much time (compared to the local transformer forward pass) - codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, codebook_num, :]) # (B*T', num_audio_tokens_per_codebook) + codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, codebook_num, :]) # (B*T', num_all_tokens_per_codebook) all_code_logits.append(codebook_logits) - all_code_logits = torch.cat(all_code_logits, dim=1) # (B*T', num_codebooks * 
num_audio_tokens_per_codebook) + all_code_logits = torch.cat(all_code_logits, dim=1) # (B*T', num_codebooks * num_all_tokens_per_codebook) all_code_logits = all_code_logits.view( audio_codes_target.size(0), audio_codes_target.size(2), -1 - ) # (B, T', C * num_audio_tokens_per_codebook) + ) # (B, T', C * num_all_tokens_per_codebook) return all_code_logits @@ -351,8 +370,8 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens): loss_mask = get_mask_from_lengths(audio_codes_lens) total_codebook_loss = None for codebook in range(audio_codes.size(1)): - si = codebook * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook + si = codebook * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) codebook_targets = audio_codes[:, codebook] # (B, T') codebook_loss = self.cross_entropy_loss( @@ -385,9 +404,9 @@ def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): # all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook) # audio_codes_lens: (B,) all_preds = [] - for idx in range(self.cfg.num_audio_codebooks): - si = idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook + for idx in range(self.num_audio_codebooks): + si = idx * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook codebook_logits = all_code_logits[:, :, si:ei] codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook) # argmax to get the tokens @@ -406,10 +425,10 @@ def sample_codes_from_local_transformer(self, dec_output, temperature=0.7, topk= dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) all_preds = [] - for codebook_num in range(self.cfg.num_audio_codebooks): + for codebook_num in range(self.num_audio_codebooks): _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128) - codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, -1, :]) # (B, num_audio_tokens_per_codebook) + codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, -1, :]) # (B, num_all_tokens_per_codebook) if use_cfg: actual_batch_size = codebook_logits.size(0) // 2 conditional_logits = codebook_logits[:actual_batch_size] @@ -446,9 +465,9 @@ def sample_codes_from_local_transformer(self, dec_output, temperature=0.7, topk= def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={}): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep all_preds = [] - for idx in range(self.cfg.num_audio_codebooks): - si = idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook + for idx in range(self.num_audio_codebooks): + si = idx * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook) for item_idx in unfinished_items: codebook_logits[item_idx, self.audio_eos_id] = float('-inf') @@ -874,8 +893,8 @@ def process_batch(self, batch, mode="train"): and torch.rand(1).item() < 0.5 ): # For some batches (half of them), replace 
decoder_input_dropout_prob of the timesteps with random tokens - max_codebook_val = self.cfg.get('dec_random_input_max', self.cfg.num_audio_tokens_per_codebook) - # @pneekhara: Keeping dec_random_input_max configurable since num_audio_tokens_per_codebook usually has padding tokens + max_codebook_val = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook) + # @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens # which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on # audio_codes_input so should not matter if we don't supply dec_random_input_max. random_audio_tokens = torch.randint( @@ -1245,7 +1264,7 @@ def infer_batch( context_tensors = self.prepare_context_tensors(batch) text = context_tensors['text'] audio_codes_bos = torch.full( - (text.size(0), self.cfg.num_audio_codebooks, 1), self.audio_bos_id, device=text.device + (text.size(0), self.num_audio_codebooks, 1), self.audio_bos_id, device=text.device ).long() audio_codes_lens = torch.full((text.size(0),), 1, device=text.device).long() audio_codes_input = audio_codes_bos @@ -1518,8 +1537,8 @@ def get_dataset(self, dataset_cfg, dataset_type): audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, context_audio_eos_id=self.context_audio_eos_id, - num_audio_codebooks=self.cfg.num_audio_codebooks, - codec_model_downsample_factor=self.cfg.codec_model_downsample_factor, + num_audio_codebooks=self.num_audio_codebooks, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=dataset_type, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) @@ -1540,13 +1559,13 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D dataset = MagpieTTSLhotseDataset( sample_rate=self.cfg.sample_rate, volume_norm=dataset_cfg.volume_norm, - codec_model_downsample_factor=self.cfg.codec_model_downsample_factor, + codec_model_samples_per_frame=self.codec_model_samples_per_framed, codec_model_name=self.cfg.codec_model_name, audio_bos_id=self.audio_bos_id, audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, context_audio_eos_id=self.context_audio_eos_id, - num_audio_codebooks=self.cfg.num_audio_codebooks, + num_audio_codebooks=self.num_audio_codebooks, prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index d35019562350..2ca3f07a7e03 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -31,7 +31,7 @@ def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0 metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) return metrics -def update_config(model_cfg, codecmodel_path): +def update_config(model_cfg, codecmodel_path, legacy_codebooks=False): ''' helper function to rename older yamls from t5 to magpie ''' model_cfg.codecmodel_path = codecmodel_path if hasattr(model_cfg, 'text_tokenizer'): @@ -47,6 +47,19 @@ def update_config(model_cfg, codecmodel_path): if "t5_decoder" in model_cfg: model_cfg.decoder = 
model_cfg.t5_decoder del model_cfg.t5_decoder + if legacy_codebooks: + print("WARNING: Using legacy codebook indices for backward compatibility. Should only be used with old checkpoints.") + num_audio_tokens_per_codebook = model_cfg.num_audio_tokens_per_codebook + model_cfg.forced_num_all_tokens_per_codebook = num_audio_tokens_per_codebook + model_cfg.forced_audio_eos_id = num_audio_tokens_per_codebook - 1 + model_cfg.forced_audio_bos_id = num_audio_tokens_per_codebook - 2 + if model_cfg.model_type == 'decoder_context_tts': + model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 3 + model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 4 + else: + model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 1 + model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 2 + return model_cfg def run_inference( @@ -71,7 +84,8 @@ def run_inference( apply_prior_to_layers=None, start_prior_after_n_audio_steps=10, confidence_level=0.95, - use_local_transformer=False + use_local_transformer=False, + legacy_codebooks=False ): # Load model if hparams_file is not None: @@ -80,7 +94,7 @@ def run_inference( model_cfg = model_cfg.cfg with open_dict(model_cfg): - model_cfg = update_config(model_cfg, codecmodel_path) + model_cfg = update_config(model_cfg, codecmodel_path, legacy_codebooks) model = MagpieTTSModel(cfg=model_cfg) model.use_kv_cache_for_inference = True @@ -93,7 +107,7 @@ def run_inference( elif nemo_file is not None: model_cfg = MagpieTTSModel.restore_from(nemo_file, return_config=True) with open_dict(model_cfg): - model_cfg = update_config(model_cfg, codecmodel_path) + model_cfg = update_config(model_cfg, codecmodel_path, legacy_codebooks) model = MagpieTTSModel.restore_from(nemo_file, override_config_path=model_cfg) model.use_kv_cache_for_inference = True checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] @@ -145,14 +159,14 @@ def run_inference( sample_rate=model_cfg.sample_rate, min_duration=0.5, max_duration=20, - codec_model_downsample_factor=model_cfg.codec_model_downsample_factor, + codec_model_samples_per_frame=model.codec_model_samples_per_frame, bos_id=model.bos_id, eos_id=model.eos_id, context_audio_bos_id=model.context_audio_bos_id, context_audio_eos_id=model.context_audio_eos_id, audio_bos_id=model.audio_bos_id, audio_eos_id=model.audio_eos_id, - num_audio_codebooks=model_cfg.num_audio_codebooks, + num_audio_codebooks=model.num_audio_codebooks, prior_scaling_factor=None, load_cached_codes_if_available=False, dataset_type='test', @@ -302,6 +316,7 @@ def main(): parser.add_argument('--asr_model_name', type=str, default="stt_en_conformer_transducer_large") # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) + parser.add_argument('--legacy_codebooks', action='store_true') args = parser.parse_args() estimate_alignment_from_layers = None @@ -340,7 +355,8 @@ def main(): apply_prior_to_layers=apply_prior_to_layers, start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, - use_local_transformer=args.use_local_transformer + use_local_transformer=args.use_local_transformer, + legacy_codebooks=args.legacy_codebooks ) return elif (args.nemo_file is not None): @@ -368,7 +384,8 @@ def main(): apply_prior_to_layers=apply_prior_to_layers, start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, 
confidence_level=args.confidence_level, - use_local_transformer=args.use_local_transformer + use_local_transformer=args.use_local_transformer, + legacy_codebooks=args.legacy_codebooks ) else: BASE_EXP_DIR = args.base_exp_dir From cae27aa31e0379fc7fecf7547ec02d91d8c94ae4 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 29 Apr 2025 06:26:57 -0700 Subject: [PATCH 028/113] [magpie][wandb][bugfix] ensure consistent validation step for audio and image on the sliding bar instead of incrementing by 1. (#61) * [magpie][wandb][bugfix] ensure consistent validation step for audio and image on the sliding bar instead of incrementing by 1. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpietts][loggers] support logging metrics using multiple loggers enabled in exp_manager. * [magpietts][lhotse_dataset] remove useless imports and functions. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- .../tts/data/text_to_speech_dataset_lhotse.py | 12 - nemo/collections/tts/models/magpietts.py | 283 ++++++++++-------- 2 files changed, 157 insertions(+), 138 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index 7daa0af57e61..ab8503da285f 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -21,11 +21,9 @@ from hydra.utils import instantiate from lhotse import CutSet from lhotse.dataset.collation import collate_matrices, collate_vectors -from megatron.core import parallel_state from omegaconf import DictConfig from transformers import AutoTokenizer, T5Tokenizer -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer from nemo.collections.tts.parts.utils.tts_dataset_utils import ( beta_binomial_prior_distribution, @@ -74,16 +72,6 @@ def check_speaker_format(item: str): return bool(re.match(pattern, item)) -def build_lhotse_dataloader(dataset, data_cfg, is_eval=False): - """Buld dataloader given an input dataset.""" - return get_lhotse_dataloader_from_config( - data_cfg, - global_rank=parallel_state.get_data_parallel_rank(), - world_size=parallel_state.get_data_parallel_world_size(), - dataset=dataset, - ) - - class MagpieTTSLhotseDataset(torch.utils.data.Dataset): """ A PyTorch Dataset for loading and processing Text-to-Speech data for diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 375f39592a50..2eea2fe9532e 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -23,6 +23,7 @@ import numpy as np import soundfile as sf import torch +import wandb from hydra.utils import instantiate from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger @@ -48,12 +49,6 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging -HAVE_WANDB = True -try: - import wandb -except ModuleNotFoundError: - HAVE_WANDB = False - def worker_init_fn(worker_id): # For 
mp.set_start_method("spawn", force=True) @@ -489,37 +484,40 @@ def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0): # attention_prob_matrix List of (B, C, audio_timesteps, text_timesteps) + wandb_images_log = {} + with torch.no_grad(): attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1) # (B, C, audio_timesteps, text_timesteps) attention_prob_matrix_mean = attention_prob_matrix.mean(dim=1) # (B, audio_timesteps, text_timesteps) - images = list() - for idx in range(min(3, attention_prob_matrix_mean.size(0))): - item_attn_matrix = attention_prob_matrix_mean[idx][ - dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx] - ] - item_attn_matrix = item_attn_matrix.detach().cpu().numpy() - images.append(plot_alignment_to_numpy(item_attn_matrix.T)) - - if isinstance(self.logger, WandbLogger) and HAVE_WANDB: - self.logger.log_image( - key=f"Image/{prefix}/attention_matrix", - images=images, - step=self.global_step, - caption=[f"Example_{idx}" for idx in range(len(images))], - ) - elif isinstance(self.logger, TensorBoardLogger): - for idx, img in enumerate(images): - self.logger.experiment.add_image( - f'{prefix}/attention_matrix/Example_{idx}', - img, - global_step=self.global_step, - dataformats="HWC", - ) - else: - ValueError(f"Invalid logger: {self.logger}") + for logger in self.loggers: + is_wandb = isinstance(logger, WandbLogger) + is_tb = isinstance(logger, TensorBoardLogger) + if not is_wandb and not is_tb: + raise ValueError(f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported.") + + wandb_images_log[f"Image/{prefix}/attention_matrix"] = list() + for idx in range(min(3, attention_prob_matrix_mean.size(0))): + item_attn_matrix = attention_prob_matrix_mean[idx][ + dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx] + ] + item_attn_matrix = item_attn_matrix.detach().cpu().numpy() + img_np = plot_alignment_to_numpy(item_attn_matrix.T) + + if is_wandb: + wandb_images_log[f"Image/{prefix}/attention_matrix"].append(wandb.Image(img_np, caption=f"Example_{idx}")) + + if is_tb: + logger.experiment.add_image( + f'{prefix}/attention_matrix/Example_{idx}', + img_np, + global_step=self.global_step, + dataformats="HWC", + ) + + return wandb_images_log - def log_train_val_audio_example( + def log_val_audio_example( self, logits, target_audio_codes, @@ -527,6 +525,8 @@ def log_train_val_audio_example( context_audio_codes=None, context_audio_codes_lens=None, ): + wandb_audio_log = {} + pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target) pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target) target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target) @@ -536,54 +536,51 @@ def log_train_val_audio_example( # > 3 ensures, it is a valid context audio tensor (and not dummy tensor used in text context) context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens) - for idx in range(min(3, pred_audio.size(0))): - pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() - target_audio_np = target_audio[idx].float().detach().cpu().numpy() - pred_audio_np = pred_audio_np[: pred_audio_lens[idx]] - target_audio_np = target_audio_np[: target_audio_lens[idx]] - context_audio_np = None - if context_audio is not None: 
- context_audio_np = context_audio[idx].float().detach().cpu().numpy() - context_audio_np = context_audio_np[: context_audio_lens[idx]] - - if isinstance(self.logger, WandbLogger) and HAVE_WANDB: - if context_audio_np is not None: - audios_np = [context_audio_np] - captions = ["context"] - else: - audios_np = list() - captions = list() - audios_np = audios_np + [pred_audio_np, target_audio_np] - captions = captions + ["prediction", "target"] - self.logger.log_audio( - key=f"Audio/Example_{idx}", - audios=audios_np, - step=self.global_step, - sample_rate=[self.cfg.sample_rate] * len(audios_np), - caption=captions, - ) - elif isinstance(self.logger, TensorBoardLogger): - if context_audio_np is not None: - self.logger.experiment.add_audio( - f'Example_{idx}/context', - context_audio_np, + for logger in self.loggers: + is_wandb = isinstance(logger, WandbLogger) + is_tb = isinstance(logger, TensorBoardLogger) + if not is_wandb and not is_tb: + raise ValueError(f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported.") + + for idx in range(min(3, pred_audio.size(0))): + pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() + target_audio_np = target_audio[idx].float().detach().cpu().numpy() + pred_audio_np = pred_audio_np[: pred_audio_lens[idx]] + target_audio_np = target_audio_np[: target_audio_lens[idx]] + context_audio_np = None + if context_audio is not None: + context_audio_np = context_audio[idx].float().detach().cpu().numpy() + context_audio_np = context_audio_np[: context_audio_lens[idx]] + + if is_wandb: + wandb_audio_log[f"Audio/Example_{idx}"] = list() + if context_audio_np is not None: + wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(context_audio_np, sample_rate=self.cfg.sample_rate, caption="context")) + wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(pred_audio_np, sample_rate=self.cfg.sample_rate, caption="prediction")) + wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(target_audio_np, sample_rate=self.cfg.sample_rate, caption="target")) + + if is_tb: + if context_audio_np is not None: + logger.experiment.add_audio( + f'Example_{idx}/context', + context_audio_np, + global_step=self.global_step, + sample_rate=self.cfg.sample_rate, + ) + logger.experiment.add_audio( + f'Example_{idx}/prediction', + pred_audio_np, + global_step=self.global_step, + sample_rate=self.cfg.sample_rate, + ) + logger.experiment.add_audio( + f'Example_{idx}/target', + target_audio_np, global_step=self.global_step, sample_rate=self.cfg.sample_rate, ) - self.logger.experiment.add_audio( - f'Example_{idx}/prediction', - pred_audio_np, - global_step=self.global_step, - sample_rate=self.cfg.sample_rate, - ) - self.logger.experiment.add_audio( - f'Example_{idx}/target', - target_audio_np, - global_step=self.global_step, - sample_rate=self.cfg.sample_rate, - ) - else: - ValueError(f"Invalid logger: {self.logger}") + + return wandb_audio_log def scale_prior(self, prior, global_step): if prior is None: @@ -1081,9 +1078,17 @@ def validation_step(self, batch, batch_idx): aligner_encoder_loss = torch.tensor(0.0, device=loss.device) if batch_idx == 0 and self.global_rank == 0: - self.log_train_val_audio_example( - logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens - ) # Currently, only logs parallel prediction (logits). 
No local transformer results + # Prepare dictionary for aggregated wandb logging + wandb_log_dict = {} + + # Get audio data for logging + wandb_log_dict.update( + self.log_val_audio_example( + logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens + ) + ) + + # Get attention image data for logging if ( self.model_type != 'decoder_pretrain_synthesizer' and len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1 @@ -1091,33 +1096,53 @@ def validation_step(self, batch, batch_idx): # cross_attn_probabilities only returned when not using flash attention ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) cross_attention_probs = [attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) if layer_idx in ctc_prior_layer_ids] - self.log_attention_probs( - cross_attention_probs, - audio_codes_lens_target, - text_lens, - prefix="val", - dec_context_size=dec_context_size, + wandb_log_dict.update( + self.log_attention_probs( + cross_attention_probs, + audio_codes_lens_target, + text_lens, + prefix="val", + dec_context_size=dec_context_size, + ) ) + for layer_idx in self.transcript_decoder_layers: cross_attention_probs = [ attn_info[layer_idx]['cross_attn_probabilities'][0] ] - self.log_attention_probs(cross_attention_probs, audio_codes_lens_target, text_lens, prefix=f"val/layer_{layer_idx}", dec_context_size=dec_context_size) + wandb_log_dict.update( + self.log_attention_probs( + cross_attention_probs, + audio_codes_lens_target, + text_lens, + prefix=f"val/layer_{layer_idx}", + dec_context_size=dec_context_size + ) + ) if batch_output['aligner_attn_soft'] is not None: - self.log_attention_probs( - [batch_output['aligner_attn_soft']], - audio_codes_lens_target, - text_lens, - prefix=f"val/aligner_encoder_attn", + wandb_log_dict.update( + self.log_attention_probs( + [batch_output['aligner_attn_soft']], + audio_codes_lens_target, + text_lens, + prefix=f"val/aligner_encoder_attn", + ) ) if batch_output['aligner_attn_hard'] is not None: - self.log_attention_probs( - [batch_output['aligner_attn_hard'].unsqueeze(1)], - audio_codes_lens_target, - text_lens, - prefix=f"val/aligner_encoder_attn_hard", + wandb_log_dict.update( + self.log_attention_probs( + [batch_output['aligner_attn_hard'].unsqueeze(1)], + audio_codes_lens_target, + text_lens, + prefix=f"val/aligner_encoder_attn_hard", + ) ) + # Perform single wandb log call if wandb is active and there is data + for logger in self.loggers: + if isinstance(logger, WandbLogger) and wandb_log_dict: + logger.experiment.log(wandb_log_dict) + local_transformer_loss = batch_output['local_transformer_loss'] val_output = { 'val_loss': loss, @@ -1483,35 +1508,41 @@ def test_step(self, batch, batch_idx): use_cfg=use_cfg, cfg_scale=cfg_scale, ) - for idx in range(predicted_audio.size(0)): - predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() - predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] - item_idx = batch_idx * test_dl_batch_size + idx - - if isinstance(self.logger, WandbLogger) and HAVE_WANDB: - log_dict = { - f"test/predicted_audio": wandb.Audio( - predicted_audio_np, sample_rate=self.cfg.sample_rate, caption=f"Predicted Audio" - ), - } - self.logger.experiment.log(log_dict, step=item_idx) - elif isinstance(self.logger, TensorBoardLogger): - self.logger.experiment.add_audio( - 'test/predicted_audio', - predicted_audio_np, - global_step=item_idx, - sample_rate=self.cfg.sample_rate, - ) - else: - 
ValueError(f"Invalid logger: {self.logger}") - - # Save the predicted audio - log_dir = self.logger.log_dir - audio_dir = os.path.join(log_dir, 'audios') - if not os.path.exists(audio_dir): - os.makedirs(audio_dir) - audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') - sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) + + for logger in self.loggers: + is_wandb = isinstance(logger, WandbLogger) + is_tb = isinstance(logger, TensorBoardLogger) + if not is_wandb and not is_tb: + raise ValueError(f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported.") + + for idx in range(predicted_audio.size(0)): + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] + item_idx = batch_idx * test_dl_batch_size + idx + + if is_wandb: + log_dict = { + f"test/predicted_audio": wandb.Audio( + predicted_audio_np, sample_rate=self.cfg.sample_rate, caption=f"Predicted Audio" + ), + } + logger.experiment.log(log_dict, step=item_idx) + + if is_tb: + logger.experiment.add_audio( + 'test/predicted_audio', + predicted_audio_np, + global_step=item_idx, + sample_rate=self.cfg.sample_rate, + ) + + # Save the predicted audio + log_dir = logger.log_dir + audio_dir = os.path.join(log_dir, 'audios') + if not os.path.exists(audio_dir): + os.makedirs(audio_dir) + audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') + sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) def on_validation_epoch_end(self): collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() From a100ad1242d59b8550041b0b6c8fe4cd27ec8d5f Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 29 Apr 2025 06:27:21 -0700 Subject: [PATCH 029/113] Codebook layout update: bugfix and README refinements (#68) * Refine the README on codebook layout updates * Typo fix * Bugfix: wire in the `legacy_codebooks` flag in a missing place --- .../tts/README_magpietts_legacy_checkpoints.md | 15 ++++++++++++++- scripts/magpietts/infer_and_evaluate.py | 3 ++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/examples/tts/README_magpietts_legacy_checkpoints.md b/examples/tts/README_magpietts_legacy_checkpoints.md index 1d08438fd0aa..a6a048ea39ab 100644 --- a/examples/tts/README_magpietts_legacy_checkpoints.md +++ b/examples/tts/README_magpietts_legacy_checkpoints.md @@ -3,7 +3,7 @@ Magpie-TTS uses special tokens like AUDIO_BOS and AUDIO_EOS for its operation. T In April 2025 we changed the layout of the embedding table in a non-backwards compatible way: -## Old Layout +## Old Layout (until April 16) With the most common codec configuration (2016 codes), the layout used to look like this: ``` | Index | Token Description | Comments | @@ -55,6 +55,8 @@ For checkpoints created before the change you can force legacy codebook layout i Just set the `--legacy_codebooks` command line option. No need to update your YAML file – The script will automatically add the overrides. ## If using a Hydra command line +This scenario would happen when either finetuning with an old checkpoint or doing data generation with an old checkpoint. 
You have two options: ### Add these to your command line ``` @@ -80,3 +82,14 @@ forced_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} #forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -3} # 2045 #forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -4} # 2044 ``` + +# Additional Details +Over the last few weeks we have gone through a few embedding table layouts. When using an old checkpoint it's important to know which layout your checkpoint was trained with and to configure the system accordingly. + +* Layout 1: used until April 16 (described in the table above). Add `--legacy_codebooks` to the `infer_and_evaluate.py` command line to run inference with this layout. + +* Layout 2: after the [config changes](https://github.com/blisc/NeMo/commit/7e2cdca74a866ecefdbe01c0076ad9b5d140ac61): 2018 tokens with special tokens at the end (2017, 2016, 2015, 2014; the last two being overwrites of codec tokens). This is an invalid layout and these checkpoints should not be used. + +* Layout 3: after the [bugfix](https://github.com/blisc/NeMo/commit/23e299a0bd14b666543b4bbcc7783f783acb0bd3) but before the [refactoring](https://github.com/blisc/NeMo/commit/8ba55061a0ebb161abff4b329e402d5307f4af98): 2024 tokens with special tokens at the end (2023, 2022, 2021, 2020). There are no automatic options provided for using this layout but it can be manually configured by updating the `hparams.yaml` file with the `forced_*` options. Set `forced_num_all_tokens_per_codebook` to `2024` and set the rest of the overrides as defined under section `# Or, add these overrides to your YAML file` above. + +* Layout 4: The new layout, [from this commit onwards](https://github.com/blisc/NeMo/commit/8ba55061a0ebb161abff4b329e402d5307f4af98): 2024 tokens but with special tokens immediately after codec tokens (2016, 2017, 2018, 2019). Training and inference with the latest version of the code automatically use this layout.
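As a quick cross-check of the layouts listed above, the snippet below computes the special-token indices for Layout 1 (what `--legacy_codebooks` forces via the `forced_*` overrides) and for Layout 4 (the current `SpecialAudioToken` ordering). It is an illustrative sketch, not part of this patch: the function names are invented, and the numbers assume a 2016-code codec with `num_audio_tokens_per_codebook: 2048`, matching the examples in this README.

```python
# Illustrative sketch only; function names are invented. The id arithmetic mirrors
# update_config(..., legacy_codebooks=True) in scripts/magpietts/infer_and_evaluate.py
# and the SpecialAudioToken ordering used by MagpieTTSModel.

def legacy_layout_ids(num_audio_tokens_per_codebook=2048, decoder_context_tts=True):
    """Layout 1: special tokens sit at the top of the embedding table."""
    ids = {
        "audio_eos": num_audio_tokens_per_codebook - 1,  # 2047
        "audio_bos": num_audio_tokens_per_codebook - 2,  # 2046
    }
    if decoder_context_tts:
        ids["context_audio_eos"] = num_audio_tokens_per_codebook - 3  # 2045
        ids["context_audio_bos"] = num_audio_tokens_per_codebook - 4  # 2044
    else:
        # Other legacy model types reuse the target-audio BOS/EOS for context audio.
        ids["context_audio_eos"] = ids["audio_eos"]
        ids["context_audio_bos"] = ids["audio_bos"]
    return ids


def current_layout_ids(codec_codebook_size=2016):
    """Layout 4: special tokens immediately follow the codec tokens."""
    return {
        "audio_bos": codec_codebook_size,              # 2016
        "audio_eos": codec_codebook_size + 1,          # 2017
        "context_audio_bos": codec_codebook_size + 2,  # 2018
        "context_audio_eos": codec_codebook_size + 3,  # 2019
    }


print(legacy_layout_ids())   # Layout 1 ids
print(current_layout_ids())  # Layout 4 ids
```

The Layout 1 values match the `# 2045` / `# 2044` comments in the YAML override example above.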
diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 2ca3f07a7e03..3a442d3b12a9 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -446,7 +446,8 @@ def main(): apply_prior_to_layers=apply_prior_to_layers, start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, - use_local_transformer=args.use_local_transformer + use_local_transformer=args.use_local_transformer, + legacy_codebooks=args.legacy_codebooks ) From db801cba634c3ac6de2ff6e78b120259e61c892f Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 2 May 2025 14:18:27 -0400 Subject: [PATCH 030/113] add update config to infer script (#70) * add update config to infer script Signed-off-by: Jason * Update infer_and_evaluate.py --------- Signed-off-by: Jason --- scripts/magpietts/infer_and_evaluate.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 3a442d3b12a9..2ece1daaa1db 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -47,7 +47,12 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False): if "t5_decoder" in model_cfg: model_cfg.decoder = model_cfg.t5_decoder del model_cfg.t5_decoder + if hasattr(model_cfg, 'decoder') and hasattr(model_cfg.decoder, 'prior_eps'): + # Added to prevent crash after removing arg from transformer_2501.py in https://github.com/blisc/NeMo/pull/56 + del model_cfg.decoder.prior_eps if legacy_codebooks: + # Added to address backward compatibility arising from + # https://github.com/blisc/NeMo/pull/64 print("WARNING: Using legacy codebook indices for backward compatibility. Should only be used with old checkpoints.") num_audio_tokens_per_codebook = model_cfg.num_audio_tokens_per_codebook model_cfg.forced_num_all_tokens_per_codebook = num_audio_tokens_per_codebook @@ -178,7 +183,7 @@ def run_inference( context_duration_max=context_durration_max, ) assert len(test_dataset) == len(manifest_records), "Dataset length and manifest length should be the same. Dataset length: {}, Manifest length: {}".format(len(test_dataset), len(manifest_records)) - + test_dataset.text_tokenizer = model.tokenizer test_dataset.text_conditioning_tokenizer = model.text_conditioning_tokenizer @@ -452,4 +457,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() From 789f040b2d903a37eb1d16a9c16784af593d71a6 Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 2 May 2025 15:04:06 -0400 Subject: [PATCH 031/113] Fix typo introduced during codec refactor (#72) * Fix typo introduced during codec refactor * Update text_to_speech_dataset_lhotse.py --- .../tts/data/text_to_speech_dataset_lhotse.py | 20 +++++++++---------- nemo/collections/tts/models/magpietts.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index ab8503da285f..a75dda38cd02 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -94,7 +94,7 @@ class MagpieTTSLhotseDataset(torch.utils.data.Dataset): resampled if necessary. volume_norm (bool): If True, applies peak volume normalization to audio waveforms. Defaults to True. 
- codec_model_downsample_factor (int): The total downsampling factor of the + codec_model_samples_per_frame (int): The total downsampling factor of the audio codec model used to generate codes. Used for padding audio and calculating number of codec frames. codec_model_name (str): Name identifier for the codec model, used to @@ -141,7 +141,7 @@ def __init__( self, sample_rate: int, volume_norm: bool = True, - codec_model_downsample_factor: int = None, + codec_model_samples_per_frame: int = None, codec_model_name: str = "21fpsCausalDecoder", audio_bos_id: int = None, audio_eos_id: int = None, @@ -169,7 +169,7 @@ def __init__( if codec_model_name not in SUPPORTED_CODEC_MODEL_NAMES: raise ValueError(f"Invalid `codec_model_name`: {codec_model_name}.") self.codec_model_name = codec_model_name - self.codec_model_downsample_factor = codec_model_downsample_factor + self.codec_model_samples_per_frame = codec_model_samples_per_frame self.num_audio_codebooks = num_audio_codebooks self.include_align_prior = prior_scaling_factor is not None @@ -186,8 +186,8 @@ def __init__( self.text_conditioning_tokenizer = None def get_num_audio_samples_to_slice(self, duration, sample_rate): - num_codec_frames = int(duration * sample_rate / self.codec_model_downsample_factor) - num_audio_samples = num_codec_frames * self.codec_model_downsample_factor + num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) + num_audio_samples = num_codec_frames * self.codec_model_samples_per_frame return num_audio_samples def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: @@ -261,11 +261,11 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: # Pad audio to be multiple of downsample factor audio = torch.nn.functional.pad( audio, - (0, self.codec_model_downsample_factor - (audio.shape[0] % self.codec_model_downsample_factor)), + (0, self.codec_model_samples_per_frame - (audio.shape[0] % self.codec_model_samples_per_frame)), value=0, ) audio_len = audio.shape[0] - spec_len = int(audio_len / self.codec_model_downsample_factor) + 1 # +1 for EOS + spec_len = int(audio_len / self.codec_model_samples_per_frame) + 1 # +1 for EOS audio_list.append(audio) audio_len_list.append(audio_len) @@ -277,7 +277,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: # Sample random duration between self.context_duration_min and self.context_duration_max _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) _num_frames_to_slice = int( - _context_duration_to_slice * self.sample_rate / self.codec_model_downsample_factor + _context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame ) if _num_frames_to_slice < context_audio_codes.shape[1]: start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) @@ -349,7 +349,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: context_audio_codes_len_list.append(context_audio_codes_len) else: # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes - context_audio = torch.zeros(self.codec_model_downsample_factor, dtype=torch.float32) + context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32) context_audio_len = context_audio.shape[0] context_audio_list.append(context_audio) context_audio_len_list.append(context_audio_len) @@ -386,7 +386,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, 
List]]: has_text_context = False if self.pad_context_text_to_max_duration: _required_len = ( - int(self.context_duration_max * self.sample_rate / self.codec_model_downsample_factor) + 2 + int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 ) # +2 for BOS and EOS if len(context_text_tokens) < _required_len: _pad_id = self.text_conditioning_tokenizer.pad_token_id diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 2eea2fe9532e..013d1ef02f84 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -1590,7 +1590,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D dataset = MagpieTTSLhotseDataset( sample_rate=self.cfg.sample_rate, volume_norm=dataset_cfg.volume_norm, - codec_model_samples_per_frame=self.codec_model_samples_per_framed, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, codec_model_name=self.cfg.codec_model_name, audio_bos_id=self.audio_bos_id, audio_eos_id=self.audio_eos_id, From d5cfd040c7a8f5fb811d3fccc69241d142a800d5 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Wed, 7 May 2025 08:46:48 -0700 Subject: [PATCH 032/113] Add MaskGit support for iteratively predicting codebooks (#69) * Integrate Frechet Distance Metric in infer_and_evaluate.py * Evaluate with frechet distance * MaskGit on codebook level (training) Implementation of MaskGit on codebook level - training only. * MaskGit: make local transformer non-causal * MaskGit sampling implementation * MaskGit sampling working * Fix top-k sampling bug * Maskgit changes (WIP) * Add some local manifest paths * Compatibility with embedding table layout changes * Prepare MaskGit for PR * Cleanup * Fix typo in use of SpecialAudioTokens enum * PR preparation - cleanup and comments * Remove frechet distance (not ready) * Comments * Comments and cleanup - Added docstrings - Moved cosine_schedule() function inside the MagpieTTSModel class (its only current user) * More comments and cleanup * Combine maskgit and local transformer configs We now use `local_transformer_type` to control the type of local transformer. At inference, providing '--use_local_transformer' will activate either Autoregressive or MaskGit sampling, whichever the checkpoint was trained with. * Remove deprecated `use_local_transformer` from config We now use `local_transformer_type` to control the type of local transformer (for both AR and MaskGit). * Update remaining config files to use `local_transformer_type` * infer_and_evaluate updates for MaskGit We use `--use_local_transformer` to activate either Autoregressive or MaskGit sampling, whichever the checkpoint was trained with. * Add a printout for debug * Address PR comments * Move some functionality to a new file in the modules direcotry, `magpietts_modules.py`. * Remove reduncant `remap_names` method in `magpietts.py` (already addressed in another PR). * Cleanup * Restore a comment that was accidentally removed. 
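The masking step referenced in this commit message works per frame: for each timestep, a fraction of the codebooks (drawn from a cosine schedule) is replaced with the MASK token, and the local transformer is trained to predict the masked codebooks. The sketch below is illustrative only and not part of the patch; it assumes the common MaskGIT schedule cosine_schedule(r) = cos(r * pi / 2), while the actual schedule and masking code added by this commit live in `nemo/collections/tts/modules/magpietts_modules.py` and `MagpieTTSModel.maskgit_create_random_mask`.

```python
# Illustrative sketch of per-frame codebook masking with a MaskGIT-style cosine schedule.
# Not part of this patch; see magpietts_modules.py / maskgit_create_random_mask for the real code.
import math

import torch


def cosine_schedule(r: torch.Tensor) -> torch.Tensor:
    # r sampled uniformly in [0, 1); returns the fraction of codebooks to mask (1.0 down toward 0.0).
    return torch.cos(r * math.pi / 2)


def random_codebook_mask(codes: torch.Tensor) -> torch.Tensor:
    # codes: (B, C, T) integer codec tokens. Returns a bool mask of the same shape,
    # True where the token should be replaced by the MASK token.
    B, C, T = codes.shape
    frac_masked = cosine_schedule(torch.rand(B, T))        # (B, T)
    n_masked = torch.ceil(frac_masked * C).long()          # (B, T): codebooks to mask per frame
    # Mark the first n_masked codebook slots per frame, then shuffle which codebooks they map to.
    random_permutations = torch.argsort(torch.rand(B, C, T), dim=1)  # (B, C, T)
    mask_indices = torch.arange(C).view(1, C, 1)
    mask = mask_indices < n_masked.view(B, 1, T)           # (B, C, T)
    return torch.gather(mask, 1, random_permutations)


# Example: 2 utterances, 8 codebooks, 5 frames of codes from a 2016-entry codec.
mask = random_codebook_mask(torch.randint(0, 2016, (2, 8, 5)))
print(mask.float().mean(dim=1))  # per-frame fraction of masked codebooks
```

At inference the commit uses the reverse procedure: with `--use_local_transformer`, MaskGit sampling iteratively re-predicts the masked codebooks (wired up in the `infer_and_evaluate.py` changes listed below).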
--- .../tts/conf/magpietts/magpietts_dc_en.yaml | 2 +- examples/tts/conf/magpietts/magpietts_en.yaml | 2 +- .../magpietts/magpietts_inference_en.yaml | 2 +- .../magpietts_inference_multilingual_v1.yaml | 2 +- .../magpietts/magpietts_lhotse_dc_en.yaml | 2 +- .../magpietts/magpietts_multilingual_v1.yaml | 2 +- nemo/collections/tts/models/magpietts.py | 287 +++++++++++++++--- .../tts/modules/magpietts_modules.py | 54 ++++ scripts/magpietts/evalset_config.py | 20 ++ scripts/magpietts/evaluate_generated_audio.py | 26 +- scripts/magpietts/infer_and_evaluate.py | 15 +- 11 files changed, 348 insertions(+), 66 deletions(-) create mode 100644 nemo/collections/tts/modules/magpietts_modules.py diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index 517eeff8dcc0..d105c3875dbf 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -50,7 +50,7 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - use_local_transformer: false + local_transformer_type: "none" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 local_transformer_n_layers: 1 diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index 7dbdc7f5f822..baea9245de2e 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -52,7 +52,7 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - use_local_transformer: false + local_transformer_type: "none" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 local_transformer_n_layers: 1 diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index 8c984a22748f..96d33405868e 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -60,7 +60,7 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - use_local_transformer: false + local_transformer_type: "none" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 local_transformer_n_layers: 1 diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index 0e87fd677f42..bd2ff60575d8 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -60,7 +60,7 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - use_local_transformer: false + local_transformer_type: "none" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 local_transformer_n_layers: 1 diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index 43f572e2d2e8..c062e61ece64 100644 --- 
a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -38,7 +38,7 @@ model: aligner_encoder_train_steps: 50_000 # Local transformer parameters for autoregressive codebook prediction within a frame - use_local_transformer: false + local_transformer_type: "none" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 local_transformer_n_layers: 1 diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index a735ce958e79..038cc301cded 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -53,7 +53,7 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - use_local_transformer: false + local_transformer_type: "none" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 local_transformer_n_layers: 1 diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 013d1ef02f84..bfe8d4005078 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -17,7 +17,6 @@ import string import time from typing import List -from enum import Enum, verify, CONTINUOUS, UNIQUE import librosa import numpy as np @@ -39,6 +38,7 @@ from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.modules.aligner import AlignmentEncoder +from nemo.collections.tts.modules.magpietts_modules import SpecialAudioToken, LocalTransformerType, cosine_schedule from nemo.collections.tts.parts.utils.helpers import ( binarize_attention_parallel, get_mask_from_lengths, @@ -62,7 +62,6 @@ def worker_init_fn(worker_id): dataset.text_tokenizer = tokenizer dataset.text_conditioning_tokenizer = text_conditioning_tokenizer - class MagpieTTSModel(ModelPT): """ Magpie-TTS Model Base Class used for training a TTS model that can generate audio codes from transcript and a context @@ -98,32 +97,19 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # del codec discriminator to free memory del codec_model.discriminator - # Reserve special tokens (appended at the end of the codebook after the actual audio codec tokens) - # (the actual index is this value plus the number of codec tokens - do not use the Enum directy) - @verify(CONTINUOUS, UNIQUE) - class SpecialAudioToken(Enum): - AUDIO_BOS = 0 - AUDIO_EOS = 1 - AUDIO_CONTEXT_BOS = 2 - AUDIO_CONTEXT_EOS = 3 - MASK_TOKEN = 4 - # Reserved so that if we need to add more special tokens in the future the codebook size will remain the same - RESERVED_1 = 5 - RESERVED_2 = 6 - RESERVED_3 = 7 - # Set up codebook configuration self.num_audio_codebooks = codec_model.num_codebooks self.codec_model_samples_per_frame = codec_model.samples_per_frame # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code. 
num_audio_tokens = codec_model.codebook_size - self.audio_bos_id = cfg.get('forced_audio_bos_id',num_audio_tokens + SpecialAudioToken.AUDIO_BOS.value) + self.audio_bos_id = cfg.get('forced_audio_bos_id', num_audio_tokens + SpecialAudioToken.AUDIO_BOS.value) self.audio_eos_id = cfg.get('forced_audio_eos_id', num_audio_tokens + SpecialAudioToken.AUDIO_EOS.value) self.context_audio_bos_id = cfg.get('forced_context_audio_bos_id', num_audio_tokens + SpecialAudioToken.AUDIO_CONTEXT_BOS.value) self.context_audio_eos_id = cfg.get('forced_context_audio_eos_id', num_audio_tokens + SpecialAudioToken.AUDIO_CONTEXT_EOS.value) self.num_all_tokens_per_codebook = cfg.get('forced_num_all_tokens_per_codebook',num_audio_tokens + len(SpecialAudioToken)) - + self.mask_token_id = cfg.get('forced_mask_token_id', num_audio_tokens + SpecialAudioToken.MASK_TOKEN.value) + # Setup tokenizer if hasattr(cfg, 'text_tokenizer'): # For backward compatibility for English-only models @@ -172,7 +158,10 @@ class SpecialAudioToken(Enum): self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) self.final_proj = nn.Linear(cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook) - if cfg.get('use_local_transformer', False): + + self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower()) + logging.info(f"Local transformer type: {self.local_transformer_type}") + if self.local_transformer_type != LocalTransformerType.NO_LT: local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256) if local_transformer_hidden_dim != cfg.decoder.d_model: self.local_transformer_in_projection = nn.Linear(cfg.decoder.d_model, local_transformer_hidden_dim) @@ -184,7 +173,7 @@ class SpecialAudioToken(Enum): d_ffn=local_transformer_hidden_dim*4, sa_n_heads=self.cfg.get('local_transformer_n_heads', 1), kernel_size=1, - is_causal=True, + is_causal=self.local_transformer_type == LocalTransformerType.AR, max_length_causal_mask=self.num_audio_codebooks+2, use_learnable_pos_emb=True, ) @@ -240,6 +229,8 @@ class SpecialAudioToken(Enum): if alignment_encoder_loss_scale > 0.0: self.alignment_encoder_loss = ForwardSumLoss(loss_scale=alignment_encoder_loss_scale) + + def state_dict(self, destination=None, prefix='', keep_vars=False): if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} @@ -250,7 +241,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): if any([substring in key for substring in keys_substrings_to_exclude]): del state_dict[key] return state_dict - + def load_state_dict(self, state_dict, strict=True): # Override to load all the keys except _speaker_verification_model and _codec_model super().load_state_dict(state_dict, strict=False) @@ -325,12 +316,28 @@ def get_speaker_embeddings(self, audio_16khz, audio_len_16khz): ) return speaker_embeddings - def compute_local_transformer_logits(self, dec_out, audio_codes_target): + def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_offset_by_one=False): """ - Loss from the autoregressive codebook predictor (used per frame) + Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes. + This function is used in training and validation, not inference/sampling. 
+ The sequence layout is slightly different between AR and MG modes, as shown in the diagram below, + (using an 8-codebook setup as an example): + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | AR target | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | none | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | MG target | none | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | Input | Magpie | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + | | Latent | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | Seq. Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + + dec_out: (B, T', E) + audio_codes_target: (B, C, T') + targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive) + if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit) """ - # dec_out: (B, T', E) - # audio_codes: (B, C, T') dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E) local_transformer_input = [dec_out_all] for codebook_num in range(audio_codes_target.size(1)): @@ -343,7 +350,12 @@ def compute_local_transformer_logits(self, dec_out, audio_codes_target): local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128) _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E) - local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E) + if not targets_offset_by_one: + # for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc. + local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E) + else: + # for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc. + local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E) all_code_logits = [] for codebook_num in range(audio_codes_target.size(1)): # Using a separate projection layer for each codebook (to distinguish between them) @@ -358,11 +370,73 @@ def compute_local_transformer_logits(self, dec_out, audio_codes_target): return all_code_logits - def compute_loss(self, logits, audio_codes, audio_codes_lens): - # logits: (B, T', num_codebooks * num_tokens_per_codebook) - # audio_codes: (B, C, T') - # audio_codes_lens: (B,) + def maskgit_create_random_mask(self, codes): + """ + Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN. + """ + # Codes: (B, C, T) + B,C,T = codes.shape + # get a uniform random vector uniformly sampled from [0,1) ## Todo does it need to be inclusive on the right? 
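# A quick numerical sketch of the schedule used just below (values approximate): cosine_schedule(x) = cos(x * pi/2),
# so draws of x = 0.25, 0.5 and 0.75 give masking fractions of ~0.92, ~0.71 and ~0.38; with C = 8 codebooks
# that yields n_masked = ceil(frac * C) = 8, 6 and 4 positions replaced by MASK_TOKEN at that frame.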
+ rand_values = torch.rand(B,T, device=codes.device) + # apply the cosine schedule + frac_masked = cosine_schedule(rand_values) + # how many positions to mask + n_masked = torch.ceil(frac_masked * C).long() # B,T + # start from all unmasked + mask = torch.zeros_like(codes, dtype=torch.bool) + # The code further below is the vectorized version of this: + # for b in range(B): + # for t in range(T): + # if n_masked[b,t] > 0: + # # get a random permutation of the codebook indices + # perm = torch.randperm(C) + # # mask the top n_masked positions + # mask[b, perm[:n_masked[b,t]], t] = True + # + # Create random permutations + random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T) + # Create a mask tensor where each position indicates if it should be masked + mask_indices = torch.arange(C, device=codes.device).view(1, C, 1) + mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T) + # Apply the random permutations to the mask + mask = torch.gather(mask, 1, random_permutations) + + return mask # (B, C, T) + + def maskgit_apply_random_mask(self, codes): + # Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule. + # Codes: (B, C, T) + mask = self.maskgit_create_random_mask(codes) + ## replace some tokens with MASK_TOKEN + codes_with_mask = torch.where(mask, self.mask_token_id, codes) + return codes_with_mask, mask + + def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None): + """ + Computes the audio codebook loss. Used by + (1) The main Magpie-TTS transformer + (2) The local transformer, for both autoregressive and MaskGit methods + + logits: (B, T', num_codebooks * num_tokens_per_codebook) + audio_codes: (B, C, T') + audio_codes_lens: (B,) + mask_tokens_mask: (B, C, T') True for tokens that were replaced with the MASK_TOKEN and should + therefore be the only ones included in the loss computation. + """ loss_mask = get_mask_from_lengths(audio_codes_lens) + if mask_tokens_mask is not None: + # For MaskGit we only compute loss for the masked tokens. + # *Both* conditions must be true: + # 1. the token is masked + # 2. 
the token is not padding + loss_mask = loss_mask.unsqueeze(1) * mask_tokens_mask + if not loss_mask.any(): + # Without this we were very rarely getting NaNs in the loss + logging.warning("No tokens valid were found in compute_loss()!") + return torch.tensor(0.0, device=loss_mask.device), loss_mask + else: + # repeat loss mask for each codebook to simplify code below + loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1) total_codebook_loss = None for codebook in range(audio_codes.size(1)): si = codebook * self.num_all_tokens_per_codebook @@ -372,8 +446,8 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens): codebook_loss = self.cross_entropy_loss( codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T') ) # (B, T') - codebook_loss = codebook_loss * loss_mask - codebook_loss = codebook_loss.sum() / loss_mask.sum() + codebook_loss = codebook_loss * loss_mask[:, codebook, :] + codebook_loss = codebook_loss.sum() / loss_mask[:, codebook, :].sum() if total_codebook_loss is None: total_codebook_loss = codebook_loss else: @@ -414,7 +488,99 @@ def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): return all_preds - def sample_codes_from_local_transformer(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0): + def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0, n_steps=3): + """ + Sample codes for one timestep from the local transformer using MaskGit. + """ + # dec_output: (B, E) + device = dec_output.device + # disable KV cache since our transformer is not causal + self.local_transformer.reset_cache(use_cache=False) + dec_output = dec_output.unsqueeze(1) # (B, 1, E) + local_transformer_input_init = self.local_transformer_in_projection(dec_output) # (B, 1, D) where D is the dimension of the local transformer + C = self.num_audio_codebooks + B = dec_output.size(0) + + min_confidence = float("-inf") + max_confidence = 10000 # this needs to be large enough that unmasked items will always remain unmasked. # TODO @rfejgin: use float('inf')? + confidences = min_confidence * torch.ones(B, C, device=device) + # initialize to all masked + codes = self.mask_token_id * torch.ones((B, C), device=device, dtype=torch.long) + sampled_codes = codes.clone() + for step in range(n_steps): + # get mask fraction + frac_masked = cosine_schedule(torch.tensor(step / (n_steps))) + # how many codebooks to mask + n_masked = torch.ceil(C * frac_masked).long() # TODO @rfejgin: should we force this to be initialized to exactly `C` (to avoid numerical issues)? 
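# Worked example of the schedule above, assuming C = 8 codebooks and n_steps = 3 (values approximate):
#   step 0: frac_masked = cos(0)      = 1.00 -> n_masked = 8, n_unmasked = 0 (all codebooks still masked)
#   step 1: frac_masked = cos(pi/6)  ~= 0.87 -> n_masked = 7, n_unmasked = 1 (keep the most confident code)
#   step 2: frac_masked = cos(pi/3)   = 0.50 -> n_masked = 4, n_unmasked = 4
# so progressively more codebooks are frozen to their sampled values as the steps proceed.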
+ n_unmasked = C - n_masked + # pick top-confidence codebooks up to n_unmasked + _, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1) + + # replace masks of the top-k confident codebooks with the the codes that were sampled for them + unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices) + codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) + + # build transformer input + local_transformer_input = local_transformer_input_init + for codebook_num in range(C): + next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze(1) # (B, 1, 768) + next_local_transformer_input = self.local_transformer_in_projection(next_local_transformer_input) # (B, 1, d_local) + local_transformer_input = torch.cat([local_transformer_input, next_local_transformer_input], dim=1) # (B, codebook_num+1, d_local) + + # run transformer + _mask = torch.ones(B, C+1, device=device) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, C+1, d_local) + + # get logits + logits = [] + for codebook_num in range(C): + # The `codebook_num+1` is to drop first position which corresponds to the magpie latent + codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, codebook_num+1, :]) # (B, num_audio_tokens_per_codebook) + logits.append(codebook_logits) + logits = torch.stack(logits, dim=1) # (B, C, num_audio_tokens_per_codebook) + + # apply CFG + if use_cfg: + actual_batch_size = logits.size(0) // 2 + conditional_logits = logits[:actual_batch_size] + unconditional_logits = logits[actual_batch_size:] + cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits + logits[:actual_batch_size] = cfg_logits + + # handle unfinished and finished items + for item_idx in unfinished_items: + logits[item_idx, self.audio_eos_id] = float('-inf') + for item_idx in finished_items: + logits[item_idx, :, :] = float('-inf') + logits[item_idx, :, self.audio_eos_id] = 0.0 + + # sample with top-k + logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk) + indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook) + logits_rescored = logits.clone() + logits_rescored[indices_to_remove] = float('-inf') + probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook) + sampled_codes = torch.multinomial(probs.view(B*C, -1), 1).view(B, C) + if use_cfg: + # TODO @rfejgin: why do we need to keep second half of the batch? 
can probably optimize this + sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size] + probs[actual_batch_size:] = probs[:actual_batch_size] + confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1) + + # set confidence to max for unmasked codebooks so that they will remain unmasked + confidences.scatter_(index=topk_indices, dim=1, src=max_confidence*torch.ones_like(topk_indices, dtype=torch.float)) + + # replace entries in sampled_codes with previously unmasked codebooks + sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) + # optionally: add noise to confidences here (as in token-critic paper) (not implemented) + + codes = sampled_codes + assert not (codes == self.mask_token_id).any(), f"Codes contain mask tokens after completion of MaskGit sampling" + if use_cfg: + codes = codes[:actual_batch_size] + return codes + + def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0): # dec_output: (B, E) self.local_transformer.reset_cache(use_cache=True) dec_output = dec_output.unsqueeze(1) # (B, 1, E) @@ -979,9 +1145,18 @@ def process_batch(self, batch, mode="train"): local_transformer_loss = None local_transformer_logits = None - if self.cfg.get('use_local_transformer', False): - local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_target) - local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target, audio_codes_lens_target) + if self.local_transformer_type != LocalTransformerType.NO_LT: + if self.local_transformer_type == LocalTransformerType.MASKGIT: + # randomly replace some positions with MASK_TOKEN + audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target) + local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_masked, targets_offset_by_one=True) + #audio_codes_masked = audio_codes_masked[:, 1:, :] + local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target, audio_codes_lens_target, mask_tokens_mask) + else: + # autoregressive + assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type" + local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_target, targets_offset_by_one=False) + local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target, audio_codes_lens_target, None) local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0) loss = loss + local_transformer_loss_scale * local_transformer_loss @@ -1281,6 +1456,7 @@ def infer_batch( start_prior_after_n_audio_steps=10, compute_all_heads_attn_maps=False, use_local_transformer_for_inference=False, + maskgit_n_steps=3 ): with torch.no_grad(): start_time = time.time() @@ -1430,18 +1606,35 @@ def infer_batch( unifinished_items = {k: v for k, v in unfinished_texts.items() if v} all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) - if self.cfg.get('use_local_transformer', False) and use_local_transformer_for_inference: - audio_codes_next = self.sample_codes_from_local_transformer( - dec_output=dec_out[:,-1,:], - temperature=temperature, - topk=topk, - unfinished_items=unifinished_items, - finished_items=finished_items, - use_cfg=use_cfg, - cfg_scale=cfg_scale - ) + if 
use_local_transformer_for_inference: + if self.local_transformer_type == LocalTransformerType.AR : + # Autoregressive sampling with local transformer + audio_codes_next = self.local_transformer_sample_autoregressive( + dec_output=dec_out[:,-1,:], + temperature=temperature, + topk=topk, + unfinished_items=unifinished_items, + finished_items=finished_items, + use_cfg=use_cfg, + cfg_scale=cfg_scale + ) + elif self.local_transformer_type == LocalTransformerType.MASKGIT: + audio_codes_next = self.local_transformer_sample_maskgit( + dec_output=dec_out[:,-1,:], + temperature=temperature, + topk=topk, + unfinished_items=unifinished_items, + finished_items=finished_items, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + n_steps=maskgit_n_steps + ) + else: + raise ValueError(f"Local transformer inference requested by but local transformer type is {self.local_transformer_type}") + # TODO @rfejgin: should we add argmax sampling for EOS here too? all_codes_next_argmax = audio_codes_next else: + # Parallel sampling from all codebooks audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) @@ -1554,7 +1747,7 @@ def on_validation_epoch_end(self): self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) - if self.cfg.get('use_local_transformer', False): + if self.local_transformer_type != LocalTransformerType.NO_LT: val_local_transformer_loss = collect("val_local_transformer_loss") self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() # free memory diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py new file mode 100644 index 000000000000..72fc063cce65 --- /dev/null +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum, StrEnum, verify, CONTINUOUS, UNIQUE +import torch + + +class LocalTransformerType(StrEnum): + """ + Enum for the type of local transformer to use in the MagpieTTS model. + These strings are the values allowed in the YAML config file. + """ + + NO_LT = "none" + AR = "autoregressive" + MASKGIT = "maskgit" + + +@verify(CONTINUOUS, UNIQUE) +class SpecialAudioToken(Enum): + """ + Enum for the special tokens to use in the MagpieTTS model. + The special tokens are appended at the end of the codebook after the actual audio codec tokens. 
+ The actual codeco index is this value below plus the number of codec tokens - do not use the Enum directly. + """ + + AUDIO_BOS = 0 + AUDIO_EOS = 1 + AUDIO_CONTEXT_BOS = 2 + AUDIO_CONTEXT_EOS = 3 + MASK_TOKEN = 4 + # Reserve these values so that if we need to add more special tokens in the future the codebook size will remain the same + RESERVED_1 = 5 + RESERVED_2 = 6 + RESERVED_3 = 7 + + +def cosine_schedule(x: torch.Tensor): + """ + Maps input values from [0, 1] to [1, 0] using the first quadrant of the cosine function. + Used for MaskGit mask scheduling. + """ + return torch.cos(x * (torch.pi / 2)) diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index b44c1601456a..ad8f999868c5 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -20,6 +20,21 @@ 'audio_dir' : '/datap/misc/Datasets/riva', 'feature_dir' : '/datap/misc/Datasets/riva', }, + 'libri_dev_clean_eval_large': { + 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_large.json', + 'audio_dir' : '/datap/misc/Datasets/LibriTTS', + 'feature_dir' : '/datap/misc/Datasets/LibriTTS', + }, + 'libri_dev_clean_eval_mid': { + 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_mid.json', + 'audio_dir' : '/datap/misc/Datasets/LibriTTS', + 'feature_dir' : '/datap/misc/Datasets/LibriTTS', + }, + 'libri_dev_clean_eval_tiny': { + 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_tiny.json', + 'audio_dir' : '/datap/misc/Datasets/LibriTTS', + 'feature_dir' : '/datap/misc/Datasets/LibriTTS', + }, 'libri_val': { 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', @@ -47,6 +62,11 @@ 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', 'load_cached_codes_if_available': False }, + 'riva_val_text_context': { + 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/RivattsEnglishLindyRodney21fps_val_nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_phoneme_tts_TextContext.json', + 'audio_dir' : "/datap/misc/Datasets/riva/RivattsEnglish", + 'feature_dir' : '/', + }, 'libri_unseen_test_shehzeen_phoneme': { 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', 'audio_dir' : '/Data/LibriTTS', diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 2d946d3a5dd6..74d862ec685e 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -127,18 +127,24 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo context_audio_filepath = os.path.join(audio_dir, context_audio_filepath) pred_audio_filepath = audio_file_lists[ridx] - if language == "en": - with torch.no_grad(): - # import ipdb; ipdb.set_trace() - pred_text = asr_model.transcribe([pred_audio_filepath])[0].text + + try: + if language == "en": + with torch.no_grad(): + # import ipdb; ipdb.set_trace() + pred_text = asr_model.transcribe([pred_audio_filepath])[0].text + pred_text = process_text(pred_text) + gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0].text + gt_audio_text = process_text(gt_audio_text) + else: + pred_text = transcribe_with_whisper(whisper_model, whisper_processor, pred_audio_filepath, 
language, device) pred_text = process_text(pred_text) - gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0].text + gt_audio_text = transcribe_with_whisper(whisper_model, whisper_processor, gt_audio_filepath, language, device) gt_audio_text = process_text(gt_audio_text) - else: - pred_text = transcribe_with_whisper(whisper_model, whisper_processor, pred_audio_filepath, language, device) - pred_text = process_text(pred_text) - gt_audio_text = transcribe_with_whisper(whisper_model, whisper_processor, gt_audio_filepath, language, device) - gt_audio_text = process_text(gt_audio_text) + except Exception as e: + print("Error during ASR: {}".format(e)) + pred_text = "" + gt_audio_text = "" if 'normalized_text' in record: gt_text = process_text(record['normalized_text']) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 2ece1daaa1db..469485671be7 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -61,6 +61,7 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False): if model_cfg.model_type == 'decoder_context_tts': model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 3 model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 4 + model_cfg.forced_mask_token_id = num_audio_tokens_per_codebook - 5 else: model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 1 model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 2 @@ -90,6 +91,7 @@ def run_inference( start_prior_after_n_audio_steps=10, confidence_level=0.95, use_local_transformer=False, + maskgit_n_steps=3, legacy_codebooks=False ): # Load model @@ -123,7 +125,8 @@ def run_inference( model.cuda() model.eval() - checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}_sv_{}".format( + checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] + checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}_MGsteps{}_sv_{}".format( checkpoint_name, temperature, topk, @@ -136,6 +139,7 @@ def run_inference( "".join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else "None", "".join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else "None", use_local_transformer, + maskgit_n_steps, sv_model ) dataset_meta_info = evalset_config.dataset_meta_info @@ -222,7 +226,8 @@ def run_inference( estimate_alignment_from_layers=estimate_alignment_from_layers, apply_prior_to_layers=apply_prior_to_layers, start_prior_after_n_audio_steps=start_prior_after_n_audio_steps, - use_local_transformer_for_inference=use_local_transformer + use_local_transformer_for_inference=use_local_transformer, + maskgit_n_steps=maskgit_n_steps ) all_rtf_metrics.append(rtf_metrics) et = time.time() @@ -307,7 +312,8 @@ def main(): parser.add_argument('--out_dir', type=str, default="/datap/misc/Evals/LocalTransformerAblations2") parser.add_argument('--temperature', type=float, default=0.6) parser.add_argument('--use_cfg', action='store_true') - parser.add_argument('--use_local_transformer', action='store_true') + parser.add_argument('--use_local_transformer', action='store_true', help="Enables use of local transformer for inference; applies to both Autoregressive and MaskGit sampling.") + parser.add_argument('--maskgit_n_steps', type=int, default=3) parser.add_argument('--cfg_scale', type=float, default=2.5) 
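# Hypothetical invocation exercising the MaskGit options added above (all other required
# arguments, e.g. checkpoint/hparams paths and dataset selection, omitted):
#   python scripts/magpietts/infer_and_evaluate.py --use_local_transformer --maskgit_n_steps 3 \
#       --use_cfg --cfg_scale 2.5 --temperature 0.6 --out_dir /tmp/magpie_evals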
parser.add_argument('--apply_attention_prior', action='store_true') parser.add_argument('--attention_prior_epsilon', type=float, default=1e-3) @@ -361,6 +367,7 @@ def main(): start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, + maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks ) return @@ -390,6 +397,7 @@ def main(): start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, + maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks ) else: @@ -452,6 +460,7 @@ def main(): start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, + maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks ) From ef5efde45bd4756d1df8992ce1fc7022edf697f6 Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Wed, 7 May 2025 08:58:28 -0700 Subject: [PATCH 033/113] Add spectral codec modules (#74) * Add spectral codec modules Signed-off-by: Ryan * Add codec module documentation Signed-off-by: Ryan --------- Signed-off-by: Ryan --- .../tts/modules/audio_codec_modules.py | 452 ++++++++++++++++++ 1 file changed, 452 insertions(+) diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index 846d5a028332..b9133ac2ea62 100755 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -663,6 +663,7 @@ def __init__( dilation: int = 1, padding: Optional[int] = None, pad_mode: str = "reflect", + activation: Optional[str] = None, ): super().__init__() if not padding: @@ -677,6 +678,11 @@ def __init__( padding_mode=pad_mode, ) self.conv = nn.utils.parametrizations.weight_norm(conv) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = torch.nn.Identity() + @property def input_types(self): @@ -697,6 +703,7 @@ def remove_weight_norm(self): @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) + out = self.activation(out) out = mask_sequence_tensor(out, input_len) return out @@ -1490,6 +1497,34 @@ def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor return dequantized + @typecheck( + input_types={ + "codes": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "indices": NeuralType(('B', 'D', 'T'), Index()), + }, + ) + def codes_to_indices(self, codes: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: + """Converts a code vector to indices. 
+ """ + codes_rearrange = rearrange(codes, 'B D T -> D B T') + codes_grouped = codes_rearrange.chunk(self.num_groups, dim=0) + indices = [] + + for codes_group, fsq_group in zip(codes_grouped, self.fsqs): + codes_group_rearrange = rearrange(codes_group, 'D B T -> B D T') + # [B, T] + indices_group = fsq_group.codes_to_indices(codes=codes_group_rearrange) + indices_group = mask_sequence_tensor(indices_group, input_len) + indices.append(indices_group) + + # concatenate along the feature dimension + indices = torch.stack(indices, dim=1) + + return indices + class ResidualBlock(NeuralModule): """ @@ -1566,6 +1601,53 @@ def forward(self, inputs, input_len): return out +class ResidualBlockV2(NeuralModule): + """ + Residual block which applies activation to output instead of input. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + activation: Name of activation function. + """ + + def __init__( + self, + channels: int, + filters: int, + kernel_size: int = 3, + activation: str = "lrelu", + ): + super(ResidualBlockV2, self).__init__() + + self.input_conv = Conv1dNorm( + in_channels=channels, out_channels=filters, kernel_size=kernel_size, activation=activation + ) + self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) + self.output_activation = CodecActivation(activation=activation, channels=channels) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} + + @property + def output_types(self): + return {"out": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} + + @typecheck() + def forward(self, inputs, input_len): + res = self.input_conv(inputs=inputs, input_len=input_len) + res = self.skip_conv(inputs=res, input_len=input_len) + out = inputs + res + out = self.output_activation(out) + return out + + class HiFiGANResBlock(NeuralModule): """ Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. @@ -2242,6 +2324,61 @@ def forward(self, audio, audio_len): return spec, spec_len +class STFTProcessor(NeuralModule): + """ + Interface for computing log magnitude STFT features. + + Args: + n_fft: Size of Fourier transform + win_length: The size of the sliding window frames for windowing and STFT. + hop_length: The distance between neighboring sliding window frames + log_guard: Value to add to magnitude STFT before taking log. 
+ """ + + def __init__(self, n_fft, win_length, hop_length, log_guard=1.0): + super(STFTProcessor, self).__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.log_guard = log_guard + self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), + "spec_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + spec_len = audio_len // self.hop_length + audio_padded = torch.nn.functional.pad(audio, (self.stft_pad_amount, self.stft_pad_amount), "reflect") + # [B, n_fft, T_spec] + fft = torch.stft( + audio_padded, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + return_complex=True, + center=False + ) + fft_mag = torch.abs(fft) + fft_mag_log = torch.log(fft_mag + self.log_guard) + fft_mag_log = mask_sequence_tensor(fft_mag_log, spec_len) + return fft_mag_log, spec_len + + class ResNetEncoder(NeuralModule): """ Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing @@ -2423,3 +2560,318 @@ def forward(self, audio, audio_len): # [B, C, T] encoded = torch.cat(outputs, dim=1) return encoded, spec_len + + +class STFTResidualBlock(NeuralModule): + """ + Block in multi-resolution STFT encoder which adds an STFT resolution to the encoder latent space, after down + sampling the input to match the time resoluton of the STFT features. + + Args: + resolution: STFT resolution, formatted as a 3-tuple (n_fft, hop_length, window_size) + input_dim: Dimension if input latenct features. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + activation: Name of activation function. + down_sample_rate: Down sample factor to reduce input by before adding STFT encoding. 
+ """ + + def __init__( + self, + resolution: Tuple[int], + input_dim: int, + filters: int, + kernel_size: int, + activation: str, + down_sample_rate: int, + ): + super(STFTResidualBlock, self).__init__() + down_sample_kernel_size = down_sample_rate * 2 + 1 + + self.down_sample_rate = down_sample_rate + self.down_sample_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=filters, + kernel_size=down_sample_kernel_size, + stride=self.down_sample_rate, + activation=activation + ) + + n_fft, hop_length, win_length = resolution + stft_dim = n_fft // 2 + 1 + self.spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) + self.spec_conv = Conv1dNorm(in_channels=stft_dim, out_channels=filters, kernel_size=kernel_size) + self.spec_act = CodecActivation(activation=activation, channels=filters) + + self.res_block = ResidualBlockV2( + channels=filters, filters=filters, kernel_size=kernel_size, activation=activation + ) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "out_len": NeuralType(tuple('B'), LengthsType()) + } + + @typecheck() + def forward(self, inputs, input_len, audio, audio_len): + out_len = input_len // self.down_sample_rate + out = self.down_sample_conv(inputs=inputs, input_len=out_len) + + spec, _ = self.spec_processor(audio=audio, audio_len=audio_len) + spec_res = self.spec_conv(inputs=spec, input_len=out_len) + out = out + spec_res + out = self.spec_act(out) + + out = self.res_block(inputs=out, input_len=out_len) + return out, out_len + + +class DownSampleResidualBlock(NeuralModule): + """ + Layer which combines a down sampling layer with a residual block. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + activation: Activation to apply in between residual convolutions. + down_sample_rate: Factor to down sample time dimension by. 
+ """ + + def __init__( + self, + channels: int, + filters: int, + kernel_size: int, + activation: str, + down_sample_rate: int, + ): + super(DownSampleResidualBlock, self).__init__() + down_sample_kernel_size = down_sample_rate * 2 + 1 + + self.down_sample_rate = down_sample_rate + self.down_sample_conv = Conv1dNorm( + in_channels=channels, + out_channels=filters, + kernel_size=down_sample_kernel_size, + stride=self.down_sample_rate, + activation=activation + ) + self.res_block = ResidualBlockV2( + channels=filters, filters=filters, kernel_size=kernel_size, activation=activation + ) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()) + } + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "out_len": NeuralType(tuple('B'), LengthsType()) + } + + @typecheck() + def forward(self, inputs, input_len): + output_len = input_len // self.down_sample_rate + out = self.down_sample_conv(inputs=inputs, input_len=output_len) + out = self.res_block(inputs=out, input_len=output_len) + return out, output_len + + +class MultiResolutionSTFTEncoder(NeuralModule): + """ + Interface for computing log magnitude STFT features. + + Args: + out_dim: Dimension of encoder output embedding. + resolutions: List of STFT resolutions, formatted as 3-tuples (n_fft, hop_length, window_size) + resolution_filter_list: List the same size as 'resolutions', specifying the number of filters in the residual + block for each STFT resolution. + down_sample_filter_list: List of filters to use for each down sampling block after initial STFT encoding. + down_sample_rate_list: List of rates to use for each down sampling block after initial STFT encoding. + The total down sample rate of the encoder will be 2**(len(resolutions)) * product(down_sample_rate_list) + kernel_size: Kernel size to use in all convolutions. + activation: Name of activation function. 
+ """ + + def __init__( + self, + out_dim: int, + resolutions: List[Tuple[int]], + resolution_filter_list: List[int], + down_sample_filter_list: Tuple[int] = (), + down_sample_rate_list: Tuple[int] = (), + kernel_size: int = 3, + activation: str = "lrelu", + ): + super(MultiResolutionSTFTEncoder, self).__init__() + assert len(resolutions) >= 1 + assert len(resolutions) == len(resolution_filter_list) + + n_fft, hop_length, win_length = resolutions[0] + input_filters = resolution_filter_list[0] + input_dim = n_fft // 2 + 1 + self.pre_spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) + self.pre_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=input_filters, + kernel_size=kernel_size, + activation=activation + ) + self.pre_res_block = ResidualBlockV2( + channels=input_filters, + filters=input_filters, + kernel_size=kernel_size, + activation=activation + ) + input_dim = input_filters + self.stft_blocks = nn.ModuleList([]) + for resolution, filters in zip(resolutions[1:], resolution_filter_list[1:]): + stft_block = STFTResidualBlock( + resolution=resolution, + input_dim=input_dim, + down_sample_rate=2, + filters=filters, + kernel_size=kernel_size, + activation=activation, + ) + self.stft_blocks.append(stft_block) + input_dim = filters + + if down_sample_filter_list and not down_sample_rate_list: + down_sample_rate_list = len(down_sample_filter_list) * [2] + + self.down_sample_blocks = nn.ModuleList([]) + for (filters, down_sample_rate) in zip(down_sample_filter_list, down_sample_rate_list): + down_sample_block = DownSampleResidualBlock( + channels=input_dim, + filters=filters, + down_sample_rate=down_sample_rate, + kernel_size=kernel_size, + activation=activation + ) + self.down_sample_blocks.append(down_sample_block) + input_dim = filters + + self.post_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=out_dim, + kernel_size=kernel_size + ) + + def remove_weight_norm(self): + self.encoder.remove_weight_norm() + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + encoded, encoded_len = self.pre_spec_processor(audio=audio, audio_len=audio_len) + encoded = self.pre_conv(inputs=encoded, input_len=encoded_len) + encoded = self.pre_res_block(inputs=encoded, input_len=encoded_len) + + for stft_block in self.stft_blocks: + encoded, encoded_len = stft_block( + inputs=encoded, input_len=encoded_len, audio=audio, audio_len=audio_len + ) + + for down_sample_block in self.down_sample_blocks: + encoded, encoded_len = down_sample_block(inputs=encoded, input_len=encoded_len) + + encoded = self.post_conv(inputs=encoded, input_len=encoded_len) + + return encoded, encoded_len + + +class VectorQuantizerIndexConverter(NeuralModule): + """ + Utility for converting indices between two FSQ definitions. 
+ + Example: + + from nemo.collections.tts.models import AudioCodecModel + from nemo.collections.tts.modules.audio_codec_modules import GroupFiniteScalarQuantizer, VectorQuantizerIndexConverter + + audio_file = "/home/audio.wav" + codec_path = "/home/SpectralCodecFps43.nemo" + + device = "cuda:0" + + audio, _ = librosa.load(audio_file, sr=sample_rate) + + audio_tensor = torch.tensor([audio]).to(device) + audio_len_tensor = torch.tensor([audio.shape[0]]).to(device) + + codec_model = AudioCodecModel.restore_from(codec_path, map_location=device) + tokens, token_len = codec_model.encode(audio=audio_tensor, audio_len=audio_len_tensor) + + fsq_new = GroupFiniteScalarQuantizer(num_groups=6, num_levels_per_group=[5, 5, 5, 5]).to(device) + + # vector_quantizer_original has 4 codebooks with 6 levels [5, 5, 5, 5, 5, 5] + # vector_quantizer_new has 6 codebooks with 4 levels [5, 5, 5, 5] + fsq_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=fsq_new + ) + + tokens_new = fsq_converter.convert_original_to_new(audio_tokens=tokens, audio_lens=token_len) + tokens_original = fsq_converter.convert_new_to_original(audio_tokens=tokens_new, audio_lens=token_len) + + """ + + def __init__(self, vector_quantizer_original, vector_quantizer_new): + super().__init__() + self.vector_quantizer_original = vector_quantizer_original + self.vector_quantizer_new = vector_quantizer_new + + # Input [batch, num_codebooks_original, time] + # Output [batch, num_codebooks_new, time] + def convert_original_to_new(self, audio_tokens, audio_lens): + audio_tokens_rearrange = rearrange(audio_tokens, 'B C T -> C B T') + audio_codes = self.vector_quantizer_original.decode(indices=audio_tokens_rearrange, input_len=audio_lens) + audio_tokens_new = self.vector_quantizer_new.codes_to_indices(codes=audio_codes, input_len=audio_lens) + return audio_tokens_new + + # Input [batch, num_codebooks_new, time] + # Output [batch, num_codebooks_original, time] + def convert_new_to_original(self, audio_tokens, audio_lens): + audio_tokens_rearrange = rearrange(audio_tokens, 'B C T -> C B T') + audio_codes = self.vector_quantizer_new.decode(indices=audio_tokens_rearrange, input_len=audio_lens) + audio_tokens_original = self.vector_quantizer_original.codes_to_indices(codes=audio_codes, input_len=audio_lens) + return audio_tokens_original From 38234592b6a921aae2e9966ba5862f495ead7e3a Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 7 May 2025 14:18:38 -0400 Subject: [PATCH 034/113] Update Checkpoint Loading (#73) * add data Signed-off-by: Jason * update the checkpoint loading script Signed-off-by: Jason * add func comments and address review comments Signed-off-by: Jason --------- Signed-off-by: Jason --- nemo/collections/tts/models/magpietts.py | 30 ++++++++++++++++++++---- scripts/magpietts/evalset_config.py | 5 ++++ scripts/magpietts/infer_and_evaluate.py | 16 ++++++++++++- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index bfe8d4005078..f3821d88fa48 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -100,7 +100,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # Set up codebook configuration self.num_audio_codebooks = codec_model.num_codebooks self.codec_model_samples_per_frame = codec_model.samples_per_frame - # Our codebooks start with actual audio codec tokens, followed by special tokens. 
+ # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code. num_audio_tokens = codec_model.codebook_size self.audio_bos_id = cfg.get('forced_audio_bos_id', num_audio_tokens + SpecialAudioToken.AUDIO_BOS.value) @@ -141,7 +141,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if self.use_text_conditioning_encoder: self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) - + # This needs to happen after super().__init__() self._codec_model = codec_model self._codec_model.freeze() #Lightning does requires_grad = False and self.eval() @@ -232,6 +232,10 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): def state_dict(self, destination=None, prefix='', keep_vars=False): + """ + Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model + from the checkpoint. The codec model is saved in a separate checkpoint. + """ if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} # Don't save the speaker verification and codec model in the state dict @@ -243,8 +247,26 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): return state_dict def load_state_dict(self, state_dict, strict=True): - # Override to load all the keys except _speaker_verification_model and _codec_model - super().load_state_dict(state_dict, strict=False) + """ + Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when + strict is True. + When strict is False, we can call pytorch's load_state_dict. + When strict is True, we loop through all parameters and rename them to enable loading. + """ + if strict == False: + super().load_state_dict(state_dict, strict=False) + for name, child in self.named_children(): + if name in ['_speaker_verification_model', '_codec_model']: + continue + if any(param.numel() > 0 for param in child.parameters()): + # If the module has parameters, we want to change the default mapping so that the state_dict gets + # loaded. + # Ex: state_dict[encoder.position_embeddings.weight] -> new_state_dict[position_embeddings.weight] + new_state_dict = {} + for key in state_dict.keys(): + if key.startswith(name): + new_state_dict[key[len(name)+1:]] = state_dict[key] # +1 for '.' 
+ child.load_state_dict(new_state_dict) def audio_to_codes(self, audio, audio_len, audio_type='target'): # audio: (B, T) diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index ad8f999868c5..8436df0538fe 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -264,4 +264,9 @@ 'whisper_language': 'pl', 'load_cached_codes_if_available': False }, + 'j_libri_unseen_test_no_codes': { + 'manifest_path' : '/home/jasoli/data_prime/manifests/test_clean_withContextAudioPaths.json', + 'audio_dir' : '/mnt/drive1/data/LibriTTS/', + 'feature_dir' : None, + }, } \ No newline at end of file diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 469485671be7..69a255c18c15 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -68,6 +68,19 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False): return model_cfg +def update_ckpt(state_dict): + new_state_dict = {} + for key in state_dict.keys(): + if 't5_encoder' in key: + new_key = key.replace('t5_encoder', 'encoder') + new_state_dict[new_key] = state_dict[key] + elif 't5_decoder' in key: + new_key = key.replace('t5_decoder', 'decoder') + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[key] = state_dict[key] + return new_state_dict + def run_inference( hparams_file, checkpoint_file, @@ -109,7 +122,8 @@ def run_inference( # Load weights from checkpoint file print("Loading weights from checkpoint") ckpt = torch.load(checkpoint_file, weights_only=False) - model.load_state_dict(ckpt['state_dict']) + state_dict = update_ckpt(ckpt['state_dict']) + model.load_state_dict(state_dict) checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] elif nemo_file is not None: model_cfg = MagpieTTSModel.restore_from(nemo_file, return_config=True) From e7723c502edbd4f94e076f1e61b752a7b3afd794 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 13 May 2025 06:59:54 -0700 Subject: [PATCH 035/113] Bugfix in load_state_dict() (#13555) * Bugfix in load_state_dict() The check for child parameters would sometimes detect incorrect keys. For example, a key named `local_transformer_out_projections.2.weight` would get detected as a child of the `local_transformer` module (but it isn't). Signed-off-by: Fejgin, Roy * Make previous commit a bit neater Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index f3821d88fa48..e155f4351032 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -264,8 +264,9 @@ def load_state_dict(self, state_dict, strict=True): # Ex: state_dict[encoder.position_embeddings.weight] -> new_state_dict[position_embeddings.weight] new_state_dict = {} for key in state_dict.keys(): - if key.startswith(name): - new_state_dict[key[len(name)+1:]] = state_dict[key] # +1 for '.' + name_with_dot = f"{name}." 
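# Example of why the trailing dot matters, using the key from the commit message above:
# with name == "local_transformer", a bare key.startswith(name) also matches
# "local_transformer_out_projections.2.weight" even though that parameter is not part of the
# `local_transformer` module, whereas
# "local_transformer_out_projections.2.weight".startswith("local_transformer.") is False,
# so only keys that really begin with "local_transformer." get remapped into the child below.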
+ if key.startswith(name_with_dot): + new_state_dict[key[len(name_with_dot):]] = state_dict[key] child.load_state_dict(new_state_dict) def audio_to_codes(self, audio, audio_len, audio_type='target'): From f8b70be4cda805015f79a18a7e64b012b0a06a96 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Wed, 21 May 2025 08:05:49 -0700 Subject: [PATCH 036/113] Magpie yaml updates for LT and gradient clipping (#13614) * yaml updates for LT and gradient clipping Signed-off-by: Paarth Neekhara * fix missing config Signed-off-by: Paarth Neekhara * missing config fix Signed-off-by: Paarth Neekhara --------- Signed-off-by: Paarth Neekhara --- examples/tts/conf/magpietts/magpietts_dc_en.yaml | 5 +++-- examples/tts/conf/magpietts/magpietts_en.yaml | 5 +++-- examples/tts/conf/magpietts/magpietts_inference_en.yaml | 5 +++-- .../conf/magpietts/magpietts_inference_multilingual_v1.yaml | 5 +++-- examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml | 5 +++-- examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml | 5 +++-- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index d105c3875dbf..075fe26f67f9 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -50,10 +50,10 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "none" # "none", "autoregressive", "maskgit" + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 + local_transformer_n_layers: 3 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 @@ -155,6 +155,7 @@ trainer: check_val_every_n_epoch: 1 num_sanity_val_steps: 0 benchmark: false + gradient_clip_val: 2.5 exp_manager: exp_dir: null diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index baea9245de2e..deeda0ed1e72 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -52,10 +52,10 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "none" # "none", "autoregressive", "maskgit" + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 + local_transformer_n_layers: 3 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 @@ -171,6 +171,7 @@ trainer: check_val_every_n_epoch: 1 num_sanity_val_steps: 0 benchmark: false + gradient_clip_val: 2.5 exp_manager: exp_dir: null diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index 96d33405868e..9b46b6cb42ee 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -60,10 +60,10 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "none" # "none", "autoregressive", "maskgit" + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only 
relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 + local_transformer_n_layers: 3 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 @@ -163,6 +163,7 @@ trainer: val_check_interval: 500 # check_val_every_n_epoch: 10 benchmark: false + gradient_clip_val: 2.5 exp_manager: exp_dir: null diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index bd2ff60575d8..10af938fed6f 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -60,10 +60,10 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "none" # "none", "autoregressive", "maskgit" + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 + local_transformer_n_layers: 3 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 @@ -211,6 +211,7 @@ trainer: val_check_interval: 500 # check_val_every_n_epoch: 10 benchmark: false + gradient_clip_val: 2.5 exp_manager: exp_dir: null diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index c062e61ece64..d6b8c355b8e3 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -38,10 +38,10 @@ model: aligner_encoder_train_steps: 50_000 # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "none" # "none", "autoregressive", "maskgit" + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 + local_transformer_n_layers: 3 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 @@ -165,6 +165,7 @@ trainer: num_sanity_val_steps: 0 benchmark: false use_distributed_sampler: false # required because Lhotse has its own handling + gradient_clip_val: 2.5 exp_manager: diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index 038cc301cded..2c958b369886 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -53,10 +53,10 @@ model: aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "none" # "none", "autoregressive", "maskgit" + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 + local_transformer_n_layers: 3 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 @@ -219,6 +219,7 @@ trainer: val_check_interval: 500 # check_val_every_n_epoch: 10 benchmark: false + gradient_clip_val: 2.5 exp_manager: exp_dir: null From 7e510a6bdc5ff0f4c237db00018b84c08189a6de Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 28 May 2025 09:30:05 -0400 Subject: [PATCH 037/113] Add CI/CD to Magpie dev branch (#13682) * update cicd 
yaml Signed-off-by: Jason * update tests Signed-off-by: Jason * undo change Signed-off-by: Jason * fix path Signed-off-by: Jason * disable failing test; add headers Signed-off-by: Jason --------- Signed-off-by: Jason --- .github/workflows/cicd-main.yml | 3162 +++++++++-------- .../magpietts_preference_optimization.py | 19 +- scripts/magpietts/codec_extraction.py | 13 + scripts/magpietts/eval_squimmos.py | 18 +- scripts/magpietts/evalset_config.py | 17 +- scripts/magpietts/evaluate_generated_audio.py | 13 + scripts/magpietts/infer_and_evaluate.py | 15 +- scripts/tts_dataset_to_lhotse/create_shars.py | 13 + .../tts/modules/test_audio_codec_modules.py | 1 + .../L2_TTS_Fast_dev_runs_Magpietts_config1.sh | 33 + 10 files changed, 1720 insertions(+), 1584 deletions(-) create mode 100644 tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 540a0f9d99fb..86e7f16ea195 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -18,6 +18,7 @@ on: - main - r** - weekly-bump + - magpietts_2503 types: [labeled] push: branches: @@ -51,7 +52,7 @@ jobs: - name: Select tests to run id: test_to_run run: | - # For manual dispatch, we replace `all` with the actual job names + # For manual dispatch, we replace `all` with the actual job names if [[ "$EVENT_NAME" == "workflow_dispatch" && "$TESTS_TO_RUN" == "all" ]]; then TESTS_TO_RUN=$(cat .github/workflows/cicd-main.yml | yq '.jobs | [to_entries[] | .key] | join(",")') @@ -59,14 +60,14 @@ jobs: elif [[ "$EVENT_NAME" == "workflow_dispatch" && "$TESTS_TO_RUN" != "all" ]]; then TESTS_TO_RUN=$TESTS_TO_RUN - # For correctly labeled PR, we replace `all` with the actual job names + # For correctly labeled PR, we replace `all` with the actual job names elif [[ "$EVENT_NAME" == "pull_request" && "$HAS_LABEL" == "true" ]]; then TESTS_TO_RUN=$(cat .github/workflows/cicd-main.yml | yq '.jobs | [to_entries[] | .key] | join(",")') # For incorrectly labeled PR, run no tests elif [[ "$EVENT_NAME" == "pull_request" && "$HAS_LABEL" != "true" ]]; then TESTS_TO_RUN="" - + # For push events, run all tests. This is so that we can generate coverage # on branch `main`. 
elif [[ "$EVENT_NAME" == "push" ]]; then @@ -355,1595 +356,1606 @@ jobs: SCRIPT: |- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L0_Setup_Test_Data_And_Models.sh - # L2: Community llava multimodal Checkpoints tests - L2_Community_vita_Checkpoints_tests_Llama3: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Community_vita_Checkpoints_tests_Llama3') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} - bash tests/functional_tests/L2_Community_vita_Checkpoints_tests_Llama3.sh + # # L2: Community llava multimodal Checkpoints tests + # L2_Community_vita_Checkpoints_tests_Llama3: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Community_vita_Checkpoints_tests_Llama3') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} + # bash tests/functional_tests/L2_Community_vita_Checkpoints_tests_Llama3.sh + + # # L2: ASR dev run + # ASR_dev_run_Speech_to_Text: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run_Speech_to_Text') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} + # bash tests/functional_tests/ASR_dev_run_Speech_to_Text.sh + + # ASR_dev_run_Speech_to_Text_WPE_CitriNet: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run_Speech_to_Text_WPE_CitriNet') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} + # bash tests/functional_tests/ASR_dev_run_Speech_to_Text_WPE_CitriNet.sh + + # ASR_dev_run_Speech_Pre-training_-_CitriNet: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run_Speech_Pre-training_-_CitriNet') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/ASR_dev_run_Speech_Pre-training_-_CitriNet.sh + + # Optional_ASR_dev_run_Speech_To_Text_Finetuning: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Optional_ASR_dev_run_Speech_To_Text_Finetuning') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/Optional_ASR_dev_run_Speech_To_Text_Finetuning.sh + # IS_OPTIONAL: true + + # Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning.sh + # IS_OPTIONAL: true + + # ASR_dev_run_Speech_to_Text_WPE_-_Conformer: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'ASR_dev_run_Speech_to_Text_WPE_-_Conformer') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/ASR_dev_run_Speech_to_Text_WPE_-_Conformer.sh + + # # L2: ASR dev run - part two + # ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer.sh + + # L2_Speech_to_Text_EMA: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_to_Text_EMA') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_to_Text_EMA.sh + + # L2_Speech_to_Text_AED: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_to_Text_AED') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_to_Text_AED.sh + + # # L2: Speaker dev run + # L2_Speaker_dev_run_Speaker_Recognition: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speaker_Recognition') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Speaker_Recognition.sh + + # L2_Speaker_dev_run_Speaker_Diarization: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speaker_Diarization') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Speaker_Diarization.sh + + # L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer.sh + + # L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference.sh + + # L2_Speaker_dev_run_Speech_to_Label: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speech_to_Label') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash 
tests/functional_tests/L2_Speaker_dev_run_Speech_to_Label.sh + + # L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference.sh + + # L2_Speaker_dev_run_Clustering_Diarizer_Inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Clustering_Diarizer_Inference') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Clustering_Diarizer_Inference.sh + + # L2_Speaker_dev_run_Neural_Diarizer_Inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Neural_Diarizer_Inference') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Neural_Diarizer_Inference.sh + + # L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation.sh + + # # L2: ASR Multi-dataloader dev run + # L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader.sh + + # L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader.sh + + # # L2: ASR Adapters + # L2_ASR_Adapters_Linear_Adapters: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Adapters_Linear_Adapters') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Adapters_Linear_Adapters.sh + + # L2_ASR_Adapters_RelPos_MHA_Adapters: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Adapters_RelPos_MHA_Adapters') + # with: + # 
RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Adapters_RelPos_MHA_Adapters.sh + + # # L2: OOMptimizer + # L2_Speech_Estimate_Duration_Bins: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Estimate_Duration_Bins') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Estimate_Duration_Bins.sh + + # # L2: OOMptimizer + # L2_Speech_Batch_Size_OOMptimizer: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Batch_Size_OOMptimizer') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Batch_Size_OOMptimizer.sh + + # # L2: OOMptimizer Canary (has a different batch schema) + # Optional_L2_Speech_Batch_Size_OOMptimizer_Canary: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Optional_L2_Speech_Batch_Size_OOMptimizer_Canary') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/Optional_L2_Speech_Batch_Size_OOMptimizer_Canary.sh + # IS_OPTIONAL: true + + # # L2: Speech Transcription + # L2_Speech_Transcription_Speech_to_Text_Transcribe: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Speech_to_Text_Transcribe') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Speech_to_Text_Transcribe.sh + + # # L2: Speech Transcription + # L2_Speech_Transcription_Canary_Transcribe_Full_Manifest: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Canary_Transcribe_Full_Manifest') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Canary_Transcribe_Full_Manifest.sh + # AFTER_SCRIPT: | + # rm -rf /tmp/preds.json transcribe.log + + # L2_Speech_Transcription_Canary_Transcribe_With_Prompt: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Canary_Transcribe_With_Prompt') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Canary_Transcribe_With_Prompt.sh + # AFTER_SCRIPT: | + # rm -rf preds.json transcribe.log + + # L2_Speech_Transcription_Canary_Transcribe_Audio_Dir: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Canary_Transcribe_Audio_Dir') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Canary_Transcribe_Audio_Dir.sh + # AFTER_SCRIPT: | + # rm -rf preds.json + # IS_OPTIONAL: true + + 
# # L2: Segmentation Tool + # L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav.sh + + # L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3.sh + + # # L2: G2P Models + # L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference.sh + + # # TODO: pleasefixme @redoctopus + # # - name: ByT5G2P training, evaluation and inference + # # run: | + # # cd examples/tts/g2p && \ + # # TIME=`date +"%Y-%m-%d-%T"` && OUTPUT_DIR_T5=output_byt5_${TIME} && \ + # # python g2p_train_and_evaluate.py \ + # # train_manifest=/home/TestData/g2p/g2p.json \ + # # validation_manifest=/home/TestData/g2p/g2p.json \ + # # model.test_ds.manifest_filepath=/home/TestData/g2p/g2p.json \ + # # trainer.max_epochs=1 \ + # # model.max_source_len=64 \ + # # trainer.devices=1 \ + # # do_training=True \ + # # do_testing=True \ + # # exp_manager.exp_dir=${OUTPUT_DIR_T5} \ + # # +exp_manager.use_datetime_version=False\ + # # +exp_manager.version=test && \ + # # python g2p_inference.py \ + # # pretrained_model=${OUTPUT_DIR_T5}/T5G2P/test/checkpoints/T5G2P.nemo \ + # # manifest_filepath=/home/TestData/g2p/g2p.json \ + # # phoneme_field=text + # # } + # # } + # # - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + # # if: "failure()" + + # L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference.sh + + # # TODO: remove +model.optim.capturable=True when Pytorch fix: https://github.com/pytorch/pytorch/pull/81858 + # # is in the release container + # # L2: NMT Attention is All You Need Training + # L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN.sh + # AFTER_SCRIPT: | + # rm -rf examples/nlp/machine_translation/nmt_results + # L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN.sh + + # L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation.sh + + # # L2: NMT Attention is All You Need Inference + # L2_NMT_Attention_is_All_You_Need_Inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Inference') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Inference.sh + + # # L2: NMT Attention is All You Need Finetuning + # L2_NMT_Attention_is_All_You_Need_Finetuning: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Finetuning') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Finetuning.sh + # AFTER_SCRIPT: | + # rm -rf examples/nlp/machine_translation/nmt_finetune + + # # L2: NMT Tarred Dataset Creation + # L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation.sh + + # L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation.sh + + # L2_Megatron_NMT_Training_TP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'L2_Megatron_NMT_Training_TP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Megatron_NMT_Training_TP2.sh + # AFTER_SCRIPT: | + # rm -rf examples/nlp/machine_translation/megatron_nmt_results + + # L2_VLM_HF_Transformer_PEFT: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_PEFT.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_VLM_HF_Transformer_PEFT_FSDP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT_FSDP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_PEFT_FSDP2.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_VLM_HF_Transformer_PEFT_4bit: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT_4bit') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_PEFT_4bit.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_VLM_HF_Transformer_SFT_FSDP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_SFT_FSDP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_SFT_FSDP2.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PEFT_notebook: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_notebook') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_notebook.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PEFT: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PEFT_nemorun: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_nemorun') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_nemorun.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PEFT_2gpu: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu') + # with: + # 
RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_2gpu.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PEFT_2gpu_FSDP2_liger: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu_FSDP2_liger') || needs.pre-flight.outputs.all == 'true' + # with: + # RUNNER: self-hosted-azure + # SCRIPT: | + # TRANSFORMERS_OFFLINE=1 HF_HOME=/home/TestData/automodel/hf_home python examples/llm/peft/automodel.py \ + # --model /home/TestData/akoumparouli/hf_mixtral_2l/ \ + # --max-steps 3 \ + # --devices 2 \ + # --strategy fsdp2 --liger + + + # L2_HF_Transformer_PEFT_2gpu_FSDP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu_FSDP2') || needs.pre-flight.outputs.all == 'true' + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_2gpu_FSDP2.sh - # L2: ASR dev run - ASR_dev_run_Speech_to_Text: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run_Speech_to_Text') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} - bash tests/functional_tests/ASR_dev_run_Speech_to_Text.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments - ASR_dev_run_Speech_to_Text_WPE_CitriNet: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run_Speech_to_Text_WPE_CitriNet') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} - bash tests/functional_tests/ASR_dev_run_Speech_to_Text_WPE_CitriNet.sh + # L2_HF_Transformer_PEFT_2gpu_nemorun: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu_nemorun') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_2gpu_nemorun.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_2gpu: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_2gpu_FSDP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu_FSDP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu_FSDP2.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_2gpu_nemorun: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'L2_HF_Transformer_SFT_2gpu_nemorun') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu_nemorun.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_FSDP2_2gpu: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_FSDP2_2gpu') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_FSDP2_2gpu.sh - ASR_dev_run_Speech_Pre-training_-_CitriNet: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run_Speech_Pre-training_-_CitriNet') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/ASR_dev_run_Speech_Pre-training_-_CitriNet.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments - Optional_ASR_dev_run_Speech_To_Text_Finetuning: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Optional_ASR_dev_run_Speech_To_Text_Finetuning') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/Optional_ASR_dev_run_Speech_To_Text_Finetuning.sh - IS_OPTIONAL: true + # L2_HF_Transformer_PT_2gpu: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_2gpu') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_2gpu.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PT_2gpu_nemorun: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_2gpu_nemorun') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_2gpu_nemorun.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PT: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_PT_nemorun: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_nemorun') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # 
SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_nemorun.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_notebook: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_notebook') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_notebook.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_nemorun: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_nemorun') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_nemorun.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # L2_HF_Transformer_SFT_TE_Acceleration: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_TE_Acceleration') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_TE_Acceleration.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + # IS_OPTIONAL: true + + # L2_HF_Transformer_PT_TE_Acceleration: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_TE_Acceleration') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_TE_Acceleration.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # # L2: SpeechLM tests + # L2_HF_Transformer_SpeechLM_SFT_2gpu: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SpeechLM_SFT_2gpu') || needs.pre-flight.outputs.all == 'true' + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SpeechLM_SFT_2gpu.sh + # AFTER_SCRIPT: | + # rm -rf nemo_experiments + + # # L2: TTS Fast dev runs 1 + # L2_TTS_Fast_dev_runs_1_Tacotron_2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_1_Tacotron_2') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_Tacotron_2.sh + + # L2_TTS_Fast_dev_runs_1_WaveGlow: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'L2_TTS_Fast_dev_runs_1_WaveGlow') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_WaveGlow.sh + + # L2_TTS_Fast_dev_runs_1_FastPitch: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_1_FastPitch') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_FastPitch.sh + + # # OPTIONAL_L2_TTS_Fast_dev_runs_1_RADTTS: + # # needs: [pre-flight, cicd-test-container-build] + # # runs-on: self-hosted-azure + # # timeout-minutes: 10 + # # container: + # # image: nemoci.azurecr.io/nemo_container:${{ github.run_id }} + # # options: + # # # --user 0:128 + # # --device=/dev/nvidia0 + # # --gpus all + # # --shm-size=8g + # # --env TRANSFORMERS_OFFLINE=0 + # # --env HYDRA_FULL_ERROR=1 + # # --volume /mnt/datadrive/TestData:/home/TestData + # # steps: + # # - name: Checkout repository + # # uses: actions/checkout@v4 + # # - run: | + # # python examples/tts/radtts.py \ + # # train_dataset=/home/TestData/an4_dataset/an4_train.json \ + # # validation_datasets=/home/TestData/an4_dataset/an4_val.json \ + # # sup_data_path=/home/TestData/an4_dataset/radtts_beta_priors \ + # # trainer.devices="[0]" \ + # # +trainer.limit_train_batches=1 \ + # # +trainer.limit_val_batches=1 \ + # # trainer.max_epochs=1 \ + # # trainer.strategy=auto \ + # # model.pitch_mean=212.35873413085938 \ + # # model.pitch_std=68.52806091308594 \ + # # model.train_ds.dataloader_params.batch_size=4 \ + # # model.train_ds.dataloader_params.num_workers=0 \ + # # model.validation_ds.dataloader_params.batch_size=4 \ + # # model.validation_ds.dataloader_params.num_workers=0 \ + # # export_dir=/home/TestData/radtts_test \ + # # model.optim.lr=0.0001 \ + # # model.modelConfig.decoder_use_partial_padding=True \ + # # ~trainer.check_val_every_n_epoch \ + # # ~model.text_normalizer \ + # # ~model.text_normalizer_call_kwargs + # # #- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + # # # if: "failure()" + + # L2_TTS_Fast_dev_runs_1_Hifigan: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_1_Hifigan') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_Hifigan.sh + + L2_TTS_Fast_dev_runs_Magpietts_config1: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_Magpietts_config1') + with: + RUNNER: self-hosted-azure + SCRIPT: |- + RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh + + # # L2: NeRF + # # L2_NeRF_DreamFusion: + # # needs: [pre-flight, cicd-test-container-build] + # # runs-on: self-hosted-azure + # # container: + # # image: nemoci.azurecr.io/nemo_container:${{ github.run_id }} + # # options: + # # # --user 0:128 + # # --device=/dev/nvidia0 + # # --gpus all + # # --shm-size=8g + # # --env TRANSFORMERS_OFFLINE=0 + # # --env HYDRA_FULL_ERROR=1 + # # --volume /mnt/datadrive/TestData:/home/TestData + # # steps: + # # - name: Checkout repository + # # uses: actions/checkout@v4 + # # - run: | + # # python 
examples/multimodal/text_to_image/nerf/main.py \ + # # trainer.num_nodes=1 \ + # # trainer.devices="[0]" \ + # # trainer.max_steps=1000 \ + # # model.prompt="a DSLR photo of a delicious hamburger" \ + # # exp_manager.exp_dir=examples/multimodal/text_to_image/nerf/dreamfusion_results + # # + # # rm -rf examples/multimodal/text_to_image/nerf/dreamfusion_results + # # - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + # # if: "failure()" + + # Speech_Checkpoints_tests: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Speech_Checkpoints_tests') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # TIMEOUT: 20 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/Speech_Checkpoints_tests.sh + # AFTER_SCRIPT: | + # rm -f examples/asr/evaluation_transcripts.json + + # L2_Stable_Diffusion_Training: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Stable_Diffusion_Training') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Stable_Diffusion_Training.sh + # AFTER_SCRIPT: | + # rm -rf examples/multimodal/text_to_image/sd_train_results + + # L2_NeMo_2_GPT_Pretraining_no_transformer_engine: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_Pretraining_no_transformer_engine') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_Pretraining_no_transformer_engine.sh + # AFTER_SCRIPT: | + # rm -rf tests/collections/llm/gpt_pretrain_results + # rm -rf tests/collections/llm/gpt_index_mappings + + # L2_NeMo_2_llama3_pretraining_recipe: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_llama3_pretraining_recipe') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_llama3_pretraining_recipe.sh + + # L2_NeMo_2_llama3_fault_tolerance_plugin: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_llama3_fault_tolerance_plugin') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_llama3_fault_tolerance_plugin.sh + + # L2_NeMo_2_llama3_straggler_detection: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_llama3_straggler_detection') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_llama3_straggler_detection.sh + + # L2_NeMo_2_GPT_DDP_Param_Parity_check: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_DDP_Param_Parity_check') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash 
tests/functional_tests/L2_NeMo_2_GPT_DDP_Param_Parity_check.sh + + # AFTER_SCRIPT: | + # rm -rf tests/collections/llm/gpt_pretrain_results + # rm -rf tests/collections/llm/gpt_index_mappings + + # L2_NeMo_2_Hyena_DDP_Pretraining_Test: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Hyena_DDP_Pretraining_Test') + # with: + # RUNNER: self-hosted-azure # Assume runner has 2 GPUs + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Hyena_DDP_Pretraining_Test.sh - Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning.sh - IS_OPTIONAL: true + # AFTER_SCRIPT: | + # rm -rf tests/collections/llm/hyena_pretrain_results/${{ github.run_id }} - ASR_dev_run_Speech_to_Text_WPE_-_Conformer: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run_Speech_to_Text_WPE_-_Conformer') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/ASR_dev_run_Speech_to_Text_WPE_-_Conformer.sh + # L2_NeMo_2_SSM_Pretraining: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_SSM_Pretraining') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_SSM_Pretraining.sh + + # L2_NeMo_2_SSM_Finetuning: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_SSM_Finetuning') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_SSM_Finetuning.sh + + # L2_NeMo_2_HF_MODEL_IMPORT: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_HF_MODEL_IMPORT') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_HF_MODEL_IMPORT.sh - # L2: ASR dev run - part two - ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer.sh + # AFTER_SCRIPT: | + # rm -rf ~/.cache/nemo/models - L2_Speech_to_Text_EMA: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_to_Text_EMA') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash 
tests/functional_tests/L2_Speech_to_Text_EMA.sh + # L2_NeMo_2_jit_callback: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_jit_callback') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_jit_callback.sh + + # L2_NeMo_2_T5_Pretraining: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_T5_Pretraining') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_T5_Pretraining.sh + # AFTER_SCRIPT: | + # rm -rf tests/collections/llm/t5_pretrain_results/${{ github.run_id }} + # rm -rf tests/collections/llm/t5_index_mappings/${{ github.run_id }} + + # L2_NeMo_2_T5_Finetuning: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_T5_Finetuning') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_T5_Finetuning.sh + # AFTER_SCRIPT: | + # rm -rf tests/collections/llm/t5_finetune_results/${{ github.run_id }} + + # L2_NeMo_2_T5_LoRA: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_T5_LoRA') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_T5_LoRA.sh + # AFTER_SCRIPT: | + # rm -rf tests/collections/llm/t5_peft_results/${{ github.run_id }} + + # L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2.sh + + # L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2.sh + + # L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2.sh + + # L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2.sh + + # L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: 
contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2.sh + + # L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2.sh + + # L2_NeMo_2_NEVA_LOAD_GENERATE: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_LOAD_GENERATE') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_LOAD_GENERATE.sh + + # L2_NEMO_2_MLLAMA_Inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_MLLAMA_Inference') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_MLLAMA_Inference.sh + + # L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2.sh + + # L2_NeMo_2_Mixtral_Pretraining: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_Pretraining') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_Pretraining.sh + + # L2_NeMo_2_GPT_SFT_TP1PP1_MBS1: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS1') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP1_MBS1.sh + + # L2_NeMo_2_GPT_SFT_TP1PP1_MBS2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP1_MBS2.sh + + # L2_NeMo_2_GPT_SFT_TP1PP2_MBS2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP2_MBS2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP2_MBS2.sh + + # L2_NeMo_2_GPT_SFT_TP2PP1_MBS2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'L2_NeMo_2_GPT_SFT_TP2PP1_MBS2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP2PP1_MBS2.sh + + # L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED.sh + + # L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1.sh + + # L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2.sh + + # L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2.sh + + # L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2.sh + + # L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED.sh + + # L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED.sh + + # L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED.sh + + # L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: 
contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat.sh + + # L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude.sh + + # L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2.sh + + # L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1.sh + + # L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1.sh + + # L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1.sh + + # L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude.sh + + # L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1.sh + + # L2_NEMO_2_LoRA_MERGE: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_LoRA_MERGE') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_LoRA_MERGE.sh + + # L2_NEMO_2_LoRA_Export: + # needs: [pre-flight, cicd-test-container-build] + # uses: 
./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_LoRA_Export') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_LoRA_Export.sh + + # L2_NEMO_2_LoRA_Inference: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_LoRA_Inference') + # with: + # RUNNER: self-hosted-azure-gpus-1 + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_LoRA_Inference.sh + + # L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact.sh + + # L2_NeMo_2_PTQ_Llama2_FP8_trtllm: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_PTQ_Llama2_FP8_trtllm') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_PTQ_Llama2_FP8_trtllm.sh + + # AFTER_SCRIPT: | + # rm -rf /tmp/nemo2_ckpt + # rm -rf /tmp/nemo2_ptq_engine + + # L2_NeMo_2_PTQ_Llama2_FP8_nemo: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_PTQ_Llama2_FP8_nemo') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_PTQ_Llama2_FP8_nemo.sh + + # AFTER_SCRIPT: | + # rm -rf /tmp/nemo2_ckpt + # rm -rf /tmp/nemo2_ptq_ckpt + + # L2_NeMo_2_Distill_Llama3_TP1PP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.pre-flight.outputs.all == 'true' + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Distill_Llama3_TP1PP2.sh + + # AFTER_SCRIPT: | + # rm -rf /tmp/nemo2_ckpt + # rm -rf /tmp/distill_logs + + # L2_NeMo_2_Prune_Llama_TP1PP2: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Prune_Llama_TP1PP2') || needs.pre-flight.outputs.all == 'true' + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Prune_Llama_TP1PP2.sh + # AFTER_SCRIPT: | + # rm -rf /tmp/nemo2_ckpt /tmp/pruned-llama + + # L2_NeMo_2_Export_In_Framework: + # needs: [pre-flight, cicd-test-container-build] + # uses: ./.github/workflows/_test_template.yml + # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Export_In_Framework') + # with: + # RUNNER: self-hosted-azure + # SCRIPT: |- + # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Export_In_Framework.sh - L2_Speech_to_Text_AED: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'L2_Speech_to_Text_AED')
- with:
- RUNNER: self-hosted-azure-gpus-1
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_to_Text_AED.sh
+ # AFTER_SCRIPT: |
+ # rm -rf /tmp/nemo2_ckpt /tmp/lambada.json
- # L2: Speaker dev run
- L2_Speaker_dev_run_Speaker_Recognition:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speaker_Recognition')
- with:
- RUNNER: self-hosted-azure-gpus-1
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Speaker_Recognition.sh
+ # L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING:
+ # needs: [pre-flight, cicd-test-container-build]
+ # uses: ./.github/workflows/_test_template.yml
+ # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING')
+ # with:
+ # RUNNER: self-hosted-azure-gpus-1
+ # SCRIPT: |-
+ # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING.sh
- L2_Speaker_dev_run_Speaker_Diarization:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speaker_Diarization')
- with:
- RUNNER: self-hosted-azure-gpus-1
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Speaker_Diarization.sh
+ # AFTER_SCRIPT: |
+ # rm -rf /tmp/nemo2_llava_next_results
- L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer')
- with:
- RUNNER: self-hosted-azure-gpus-1
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer.sh
+ # L2_NeMo_2_VLLM_EXPORT:
+ # needs: [pre-flight, cicd-test-container-build]
+ # uses: ./.github/workflows/_test_template.yml
+ # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_VLLM_EXPORT')
+ # with:
+ # RUNNER: self-hosted-azure
+ # SCRIPT: |-
+ # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_VLLM_EXPORT.sh
+
+ # AFTER_SCRIPT: |
+ # rm -rf /tmp/llama_head64
+ # rm -rf /tmp/nemo2_ckpt
+ # rm -rf /tmp/vllm_from_nemo2
+
+ # L2_NeMo_2_EVAL:
+ # needs: [pre-flight, cicd-test-container-build]
+ # uses: ./.github/workflows/_test_template.yml
+ # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_EVAL')
+ # with:
+ # RUNNER: self-hosted-azure-gpus-1
+ # SCRIPT: |-
+ # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_EVAL.sh
- L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference')
- with:
- RUNNER: self-hosted-azure
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference.sh
+ # AFTER_SCRIPT: |
+ # rm -rf /tmp/trtllm_dir
- L2_Speaker_dev_run_Speech_to_Label:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speech_to_Label')
- with:
- RUNNER: self-hosted-azure-gpus-1
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Speech_to_Label.sh
+ # L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124:
+ # needs: [pre-flight, cicd-test-container-build]
+ # uses: ./.github/workflows/_test_template.yml
+ # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124')
+ # with:
+ # RUNNER: self-hosted-azure-gpus-1
+ # SCRIPT: |-
+ # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124.sh
+ # AFTER_SCRIPT: |
+ # rm -rf examples/llm/auto_configurator/auto_conf_logs
+
+ # L2_SpeechLM_LoRA_TP1PP1_MBS2:
+ # needs: [pre-flight, cicd-test-container-build]
+ # uses: ./.github/workflows/_test_template.yml
+ # if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_SpeechLM_LoRA_TP1PP1_MBS2') || needs.pre-flight.outputs.all == 'true'
+ # with:
+ # RUNNER: self-hosted-azure
+ # SCRIPT: |-
+ # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_SpeechLM_LoRA_TP1PP1_MBS2.sh
- L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference')
- with:
- RUNNER: self-hosted-azure
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference.sh
+ # AFTER_SCRIPT: |
+ # rm -rf /tmp/nemo2_speechlm_lora/${{ github.run_id }}
- L2_Speaker_dev_run_Clustering_Diarizer_Inference:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Clustering_Diarizer_Inference')
- with:
- RUNNER: self-hosted-azure
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Clustering_Diarizer_Inference.sh
+ Nemo_CICD_Test:
+ needs:
+ - pre-flight
+ - cicd-import-tests
- L2_Speaker_dev_run_Neural_Diarizer_Inference:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Neural_Diarizer_Inference')
- with:
- RUNNER: self-hosted-azure
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Neural_Diarizer_Inference.sh
+ - L0_Unit_Tests_GPU_ASR
+ - L0_Unit_Tests_GPU_Audio
+ - L0_Unit_Tests_GPU_Common
+ - L0_Unit_Tests_GPU_LLM
+ - L0_Unit_Tests_GPU_Multimodal
+ - L0_Unit_Tests_GPU_TTS
+ - L0_Unit_Tests_GPU_Core
+ - L0_Unit_Tests_GPU_Hydra
+ - L0_Unit_Tests_GPU_Lightning
+ - L0_Unit_Tests_GPU_Others
- L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation:
- needs: [pre-flight, cicd-test-container-build]
- uses: ./.github/workflows/_test_template.yml
- if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation')
- with:
- RUNNER: self-hosted-azure
- SCRIPT: |-
- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation.sh
+ - L0_Unit_Tests_CPU_ASR
+ - L0_Unit_Tests_CPU_Audio
+ - L0_Unit_Tests_CPU_Common
+ - L0_Unit_Tests_CPU_LLM
+ - L0_Unit_Tests_CPU_Multimodal
+ - L0_Unit_Tests_CPU_TTS
+ - L0_Unit_Tests_CPU_Core
+ - L0_Unit_Tests_CPU_Hydra
+ - L0_Unit_Tests_CPU_Lightning
+ - L0_Unit_Tests_CPU_Others
-
- # L2: ASR Multi-dataloader dev run
- L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader:
- needs:
[pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader.sh - - L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader.sh - - # L2: ASR Adapters - L2_ASR_Adapters_Linear_Adapters: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Adapters_Linear_Adapters') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Adapters_Linear_Adapters.sh - - L2_ASR_Adapters_RelPos_MHA_Adapters: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_ASR_Adapters_RelPos_MHA_Adapters') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_ASR_Adapters_RelPos_MHA_Adapters.sh - - # L2: OOMptimizer - L2_Speech_Estimate_Duration_Bins: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Estimate_Duration_Bins') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Estimate_Duration_Bins.sh - - # L2: OOMptimizer - L2_Speech_Batch_Size_OOMptimizer: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Batch_Size_OOMptimizer') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Batch_Size_OOMptimizer.sh - - # L2: OOMptimizer Canary (has a different batch schema) - Optional_L2_Speech_Batch_Size_OOMptimizer_Canary: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Optional_L2_Speech_Batch_Size_OOMptimizer_Canary') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/Optional_L2_Speech_Batch_Size_OOMptimizer_Canary.sh - IS_OPTIONAL: true - - # L2: Speech Transcription - L2_Speech_Transcription_Speech_to_Text_Transcribe: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Speech_to_Text_Transcribe') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Speech_to_Text_Transcribe.sh - - # L2: Speech Transcription - L2_Speech_Transcription_Canary_Transcribe_Full_Manifest: - needs: [pre-flight, 
cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Canary_Transcribe_Full_Manifest') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Canary_Transcribe_Full_Manifest.sh - AFTER_SCRIPT: | - rm -rf /tmp/preds.json transcribe.log - - L2_Speech_Transcription_Canary_Transcribe_With_Prompt: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Canary_Transcribe_With_Prompt') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Canary_Transcribe_With_Prompt.sh - AFTER_SCRIPT: | - rm -rf preds.json transcribe.log - - L2_Speech_Transcription_Canary_Transcribe_Audio_Dir: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Speech_Transcription_Canary_Transcribe_Audio_Dir') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Speech_Transcription_Canary_Transcribe_Audio_Dir.sh - AFTER_SCRIPT: | - rm -rf preds.json - IS_OPTIONAL: true - - # L2: Segmentation Tool - L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav.sh - - L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3.sh - - # L2: G2P Models - L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference.sh - - # TODO: pleasefixme @redoctopus - # - name: ByT5G2P training, evaluation and inference - # run: | - # cd examples/tts/g2p && \ - # TIME=`date +"%Y-%m-%d-%T"` && OUTPUT_DIR_T5=output_byt5_${TIME} && \ - # python g2p_train_and_evaluate.py \ - # train_manifest=/home/TestData/g2p/g2p.json \ - # validation_manifest=/home/TestData/g2p/g2p.json \ - # model.test_ds.manifest_filepath=/home/TestData/g2p/g2p.json \ - # trainer.max_epochs=1 \ - # model.max_source_len=64 \ - # trainer.devices=1 \ - # do_training=True \ - # do_testing=True \ - # exp_manager.exp_dir=${OUTPUT_DIR_T5} \ - # +exp_manager.use_datetime_version=False\ - # 
+exp_manager.version=test && \ - # python g2p_inference.py \ - # pretrained_model=${OUTPUT_DIR_T5}/T5G2P/test/checkpoints/T5G2P.nemo \ - # manifest_filepath=/home/TestData/g2p/g2p.json \ - # phoneme_field=text - # } - # } - # - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - # if: "failure()" - - L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference.sh - - # TODO: remove +model.optim.capturable=True when Pytorch fix: https://github.com/pytorch/pytorch/pull/81858 - # is in the release container - # L2: NMT Attention is All You Need Training - L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN.sh - AFTER_SCRIPT: | - rm -rf examples/nlp/machine_translation/nmt_results - L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN.sh - - L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation.sh - - # L2: NMT Attention is All You Need Inference - L2_NMT_Attention_is_All_You_Need_Inference: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Inference') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Inference.sh - - # L2: NMT Attention is All You Need Finetuning - L2_NMT_Attention_is_All_You_Need_Finetuning: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Attention_is_All_You_Need_Finetuning') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Attention_is_All_You_Need_Finetuning.sh - AFTER_SCRIPT: | - rm -rf examples/nlp/machine_translation/nmt_finetune - - # L2: NMT Tarred Dataset Creation - 
L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation.sh - - L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation.sh - - L2_Megatron_NMT_Training_TP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Megatron_NMT_Training_TP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Megatron_NMT_Training_TP2.sh - AFTER_SCRIPT: | - rm -rf examples/nlp/machine_translation/megatron_nmt_results - - L2_VLM_HF_Transformer_PEFT: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_PEFT.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_VLM_HF_Transformer_PEFT_FSDP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT_FSDP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_PEFT_FSDP2.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_VLM_HF_Transformer_PEFT_4bit: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT_4bit') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_PEFT_4bit.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_VLM_HF_Transformer_SFT_FSDP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_VLM_HF_Transformer_SFT_FSDP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_VLM_HF_Transformer_SFT_FSDP2.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PEFT_notebook: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_notebook') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_notebook.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PEFT: - needs: [pre-flight, cicd-test-container-build] - uses: 
./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PEFT_nemorun: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_nemorun') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_nemorun.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PEFT_2gpu: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_2gpu.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PEFT_2gpu_FSDP2_liger: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu_FSDP2_liger') || needs.pre-flight.outputs.all == 'true' - with: - RUNNER: self-hosted-azure - SCRIPT: | - TRANSFORMERS_OFFLINE=1 HF_HOME=/home/TestData/automodel/hf_home python examples/llm/peft/automodel.py \ - --model /home/TestData/akoumparouli/hf_mixtral_2l/ \ - --max-steps 3 \ - --devices 2 \ - --strategy fsdp2 --liger - - - L2_HF_Transformer_PEFT_2gpu_FSDP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu_FSDP2') || needs.pre-flight.outputs.all == 'true' - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_2gpu_FSDP2.sh - - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PEFT_2gpu_nemorun: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu_nemorun') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PEFT_2gpu_nemorun.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_2gpu: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_2gpu_FSDP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu_FSDP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu_FSDP2.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_2gpu_nemorun: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - 
if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu_nemorun') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu_nemorun.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_FSDP2_2gpu: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_FSDP2_2gpu') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_FSDP2_2gpu.sh - - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PT_2gpu: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_2gpu') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_2gpu.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PT_2gpu_nemorun: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_2gpu_nemorun') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_2gpu_nemorun.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PT: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_PT_nemorun: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_nemorun') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_nemorun.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_notebook: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_notebook') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_notebook.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT.sh - 
AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_nemorun: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_nemorun') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_nemorun.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - L2_HF_Transformer_SFT_TE_Acceleration: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SFT_TE_Acceleration') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SFT_TE_Acceleration.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - IS_OPTIONAL: true - - L2_HF_Transformer_PT_TE_Acceleration: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_PT_TE_Acceleration') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_PT_TE_Acceleration.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - # L2: SpeechLM tests - L2_HF_Transformer_SpeechLM_SFT_2gpu: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SpeechLM_SFT_2gpu') || needs.pre-flight.outputs.all == 'true' - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_HF_Transformer_SpeechLM_SFT_2gpu.sh - AFTER_SCRIPT: | - rm -rf nemo_experiments - - # L2: TTS Fast dev runs 1 - L2_TTS_Fast_dev_runs_1_Tacotron_2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_1_Tacotron_2') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_Tacotron_2.sh - - L2_TTS_Fast_dev_runs_1_WaveGlow: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_1_WaveGlow') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_WaveGlow.sh - - L2_TTS_Fast_dev_runs_1_FastPitch: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_1_FastPitch') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_FastPitch.sh - - # OPTIONAL_L2_TTS_Fast_dev_runs_1_RADTTS: - # needs: [pre-flight, cicd-test-container-build] - # runs-on: self-hosted-azure - # timeout-minutes: 10 - # container: - # image: nemoci.azurecr.io/nemo_container:${{ github.run_id }} - # options: - # # --user 0:128 - # --device=/dev/nvidia0 - # --gpus all - # --shm-size=8g - # --env TRANSFORMERS_OFFLINE=0 - # --env HYDRA_FULL_ERROR=1 - # --volume /mnt/datadrive/TestData:/home/TestData - # steps: - # - name: Checkout repository - # uses: actions/checkout@v4 - # - run: | - # python 
examples/tts/radtts.py \ - # train_dataset=/home/TestData/an4_dataset/an4_train.json \ - # validation_datasets=/home/TestData/an4_dataset/an4_val.json \ - # sup_data_path=/home/TestData/an4_dataset/radtts_beta_priors \ - # trainer.devices="[0]" \ - # +trainer.limit_train_batches=1 \ - # +trainer.limit_val_batches=1 \ - # trainer.max_epochs=1 \ - # trainer.strategy=auto \ - # model.pitch_mean=212.35873413085938 \ - # model.pitch_std=68.52806091308594 \ - # model.train_ds.dataloader_params.batch_size=4 \ - # model.train_ds.dataloader_params.num_workers=0 \ - # model.validation_ds.dataloader_params.batch_size=4 \ - # model.validation_ds.dataloader_params.num_workers=0 \ - # export_dir=/home/TestData/radtts_test \ - # model.optim.lr=0.0001 \ - # model.modelConfig.decoder_use_partial_padding=True \ - # ~trainer.check_val_every_n_epoch \ - # ~model.text_normalizer \ - # ~model.text_normalizer_call_kwargs - # #- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - # # if: "failure()" - - L2_TTS_Fast_dev_runs_1_Hifigan: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_1_Hifigan') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_Hifigan.sh - - # L2: NeRF - # L2_NeRF_DreamFusion: - # needs: [pre-flight, cicd-test-container-build] - # runs-on: self-hosted-azure - # container: - # image: nemoci.azurecr.io/nemo_container:${{ github.run_id }} - # options: - # # --user 0:128 - # --device=/dev/nvidia0 - # --gpus all - # --shm-size=8g - # --env TRANSFORMERS_OFFLINE=0 - # --env HYDRA_FULL_ERROR=1 - # --volume /mnt/datadrive/TestData:/home/TestData - # steps: - # - name: Checkout repository - # uses: actions/checkout@v4 - # - run: | - # python examples/multimodal/text_to_image/nerf/main.py \ - # trainer.num_nodes=1 \ - # trainer.devices="[0]" \ - # trainer.max_steps=1000 \ - # model.prompt="a DSLR photo of a delicious hamburger" \ - # exp_manager.exp_dir=examples/multimodal/text_to_image/nerf/dreamfusion_results - # - # rm -rf examples/multimodal/text_to_image/nerf/dreamfusion_results - # - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - # if: "failure()" - - Speech_Checkpoints_tests: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'Speech_Checkpoints_tests') - with: - RUNNER: self-hosted-azure-gpus-1 - TIMEOUT: 20 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/Speech_Checkpoints_tests.sh - AFTER_SCRIPT: | - rm -f examples/asr/evaluation_transcripts.json - - L2_Stable_Diffusion_Training: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_Stable_Diffusion_Training') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_Stable_Diffusion_Training.sh - AFTER_SCRIPT: | - rm -rf examples/multimodal/text_to_image/sd_train_results - - L2_NeMo_2_GPT_Pretraining_no_transformer_engine: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_Pretraining_no_transformer_engine') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash 
tests/functional_tests/L2_NeMo_2_GPT_Pretraining_no_transformer_engine.sh - AFTER_SCRIPT: | - rm -rf tests/collections/llm/gpt_pretrain_results - rm -rf tests/collections/llm/gpt_index_mappings - - L2_NeMo_2_llama3_pretraining_recipe: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_llama3_pretraining_recipe') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_llama3_pretraining_recipe.sh - - L2_NeMo_2_llama3_fault_tolerance_plugin: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_llama3_fault_tolerance_plugin') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_llama3_fault_tolerance_plugin.sh - - L2_NeMo_2_llama3_straggler_detection: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_llama3_straggler_detection') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_llama3_straggler_detection.sh - - L2_NeMo_2_GPT_DDP_Param_Parity_check: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_DDP_Param_Parity_check') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_DDP_Param_Parity_check.sh - - AFTER_SCRIPT: | - rm -rf tests/collections/llm/gpt_pretrain_results - rm -rf tests/collections/llm/gpt_index_mappings - - L2_NeMo_2_Hyena_DDP_Pretraining_Test: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Hyena_DDP_Pretraining_Test') - with: - RUNNER: self-hosted-azure # Assume runner has 2 GPUs - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Hyena_DDP_Pretraining_Test.sh - - AFTER_SCRIPT: | - rm -rf tests/collections/llm/hyena_pretrain_results/${{ github.run_id }} - - L2_NeMo_2_SSM_Pretraining: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_SSM_Pretraining') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_SSM_Pretraining.sh - - L2_NeMo_2_SSM_Finetuning: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_SSM_Finetuning') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_SSM_Finetuning.sh - - L2_NeMo_2_HF_MODEL_IMPORT: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_HF_MODEL_IMPORT') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_HF_MODEL_IMPORT.sh - - AFTER_SCRIPT: | - rm -rf ~/.cache/nemo/models - - L2_NeMo_2_jit_callback: - needs: 
[pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_jit_callback') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_jit_callback.sh - - L2_NeMo_2_T5_Pretraining: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_T5_Pretraining') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_T5_Pretraining.sh - AFTER_SCRIPT: | - rm -rf tests/collections/llm/t5_pretrain_results/${{ github.run_id }} - rm -rf tests/collections/llm/t5_index_mappings/${{ github.run_id }} - - L2_NeMo_2_T5_Finetuning: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_T5_Finetuning') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_T5_Finetuning.sh - AFTER_SCRIPT: | - rm -rf tests/collections/llm/t5_finetune_results/${{ github.run_id }} - - L2_NeMo_2_T5_LoRA: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_T5_LoRA') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_T5_LoRA.sh - AFTER_SCRIPT: | - rm -rf tests/collections/llm/t5_peft_results/${{ github.run_id }} - - L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2.sh - - L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2.sh - - L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2.sh - - L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2.sh - - L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash 
tests/functional_tests/L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2.sh - - L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2.sh - - L2_NeMo_2_NEVA_LOAD_GENERATE: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NEVA_LOAD_GENERATE') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NEVA_LOAD_GENERATE.sh - - L2_NEMO_2_MLLAMA_Inference: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_MLLAMA_Inference') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_MLLAMA_Inference.sh - - L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2.sh - - L2_NeMo_2_Mixtral_Pretraining: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_Pretraining') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_Pretraining.sh - - L2_NeMo_2_GPT_SFT_TP1PP1_MBS1: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS1') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP1_MBS1.sh - - L2_NeMo_2_GPT_SFT_TP1PP1_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP1_MBS2.sh - - L2_NeMo_2_GPT_SFT_TP1PP2_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP2_MBS2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP2_MBS2.sh - - L2_NeMo_2_GPT_SFT_TP2PP1_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP2PP1_MBS2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP2PP1_MBS2.sh - - L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: 
contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED.sh - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1.sh - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2.sh - - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2.sh - - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2.sh - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED.sh - - L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED.sh - - L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED.sh - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat.sh - - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude') - with: - 
RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude.sh - - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2.sh - - L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1.sh - - L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1.sh - - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1.sh - - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude.sh - - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1.sh - - L2_NEMO_2_LoRA_MERGE: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_LoRA_MERGE') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_LoRA_MERGE.sh - - L2_NEMO_2_LoRA_Export: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_LoRA_Export') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_LoRA_Export.sh - - L2_NEMO_2_LoRA_Inference: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NEMO_2_LoRA_Inference') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NEMO_2_LoRA_Inference.sh - - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: - 
needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact.sh - - L2_NeMo_2_PTQ_Llama2_FP8_trtllm: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_PTQ_Llama2_FP8_trtllm') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_PTQ_Llama2_FP8_trtllm.sh - - AFTER_SCRIPT: | - rm -rf /tmp/nemo2_ckpt - rm -rf /tmp/nemo2_ptq_engine - - L2_NeMo_2_PTQ_Llama2_FP8_nemo: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_PTQ_Llama2_FP8_nemo') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_PTQ_Llama2_FP8_nemo.sh - - AFTER_SCRIPT: | - rm -rf /tmp/nemo2_ckpt - rm -rf /tmp/nemo2_ptq_ckpt - - L2_NeMo_2_Distill_Llama3_TP1PP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.pre-flight.outputs.all == 'true' - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Distill_Llama3_TP1PP2.sh - - AFTER_SCRIPT: | - rm -rf /tmp/nemo2_ckpt - rm -rf /tmp/distill_logs - - L2_NeMo_2_Prune_Llama_TP1PP2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Prune_Llama_TP1PP2') || needs.pre-flight.outputs.all == 'true' - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Prune_Llama_TP1PP2.sh - AFTER_SCRIPT: | - rm -rf /tmp/nemo2_ckpt /tmp/pruned-llama - - L2_NeMo_2_Export_In_Framework: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Export_In_Framework') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Export_In_Framework.sh - - AFTER_SCRIPT: | - rm -rf /tmp/nemo2_ckpt /tmp/lambada.json - - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING.sh - - AFTER_SCRIPT: | - rm -rf /tmp/nemo2_llava_next_results - - L2_NeMo_2_VLLM_EXPORT: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_VLLM_EXPORT') - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_VLLM_EXPORT.sh - - AFTER_SCRIPT: | - rm -rf /tmp/llama_head64 - rm -rf /tmp/nemo2_ckpt - rm -rf /tmp/vllm_from_nemo2 - - L2_NeMo_2_EVAL: - needs: [pre-flight, 
cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_EVAL') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_EVAL.sh - - AFTER_SCRIPT: | - rm -rf /tmp/trtllm_dir - - L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124') - with: - RUNNER: self-hosted-azure-gpus-1 - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124.sh - AFTER_SCRIPT: | - rm -rf examples/llm/auto_configurator/auto_conf_logs - - L2_SpeechLM_LoRA_TP1PP1_MBS2: - needs: [pre-flight, cicd-test-container-build] - uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_SpeechLM_LoRA_TP1PP1_MBS2') || needs.pre-flight.outputs.all == 'true' - with: - RUNNER: self-hosted-azure - SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_SpeechLM_LoRA_TP1PP1_MBS2.sh - - AFTER_SCRIPT: | - rm -rf /tmp/nemo2_speechlm_lora/${{ github.run_id }} - - Nemo_CICD_Test: - needs: - - pre-flight - - cicd-import-tests - - - L0_Unit_Tests_GPU_ASR - - L0_Unit_Tests_GPU_Audio - - L0_Unit_Tests_GPU_Common - - L0_Unit_Tests_GPU_LLM - - L0_Unit_Tests_GPU_Multimodal - - L0_Unit_Tests_GPU_TTS - - L0_Unit_Tests_GPU_Core - - L0_Unit_Tests_GPU_Hydra - - L0_Unit_Tests_GPU_Lightning - - L0_Unit_Tests_GPU_Others - - - L0_Unit_Tests_CPU_ASR - - L0_Unit_Tests_CPU_Audio - - L0_Unit_Tests_CPU_Common - - L0_Unit_Tests_CPU_LLM - - L0_Unit_Tests_CPU_Multimodal - - L0_Unit_Tests_CPU_TTS - - L0_Unit_Tests_CPU_Core - - L0_Unit_Tests_CPU_Hydra - - L0_Unit_Tests_CPU_Lightning - - L0_Unit_Tests_CPU_Others - - - ASR_dev_run_Speech_to_Text - - ASR_dev_run_Speech_to_Text_WPE_CitriNet - - ASR_dev_run_Speech_Pre-training_-_CitriNet - - Optional_ASR_dev_run_Speech_To_Text_Finetuning - - Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning - - ASR_dev_run_Speech_to_Text_WPE_-_Conformer - - ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer - - L2_Speech_to_Text_EMA - - L2_Speaker_dev_run_Speaker_Recognition - - L2_Speaker_dev_run_Speaker_Diarization - - L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer - - L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference - - L2_Speaker_dev_run_Speech_to_Label - - L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference - - L2_Speaker_dev_run_Clustering_Diarizer_Inference - - L2_Speaker_dev_run_Neural_Diarizer_Inference - - L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation - - L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader - - L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader - - L2_ASR_Adapters_Linear_Adapters - - L2_ASR_Adapters_RelPos_MHA_Adapters - - L2_Speech_Transcription_Speech_to_Text_Transcribe - - L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav - - L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3 - - L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference - - L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference - - L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN - - L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN - - L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation - - 
L2_NMT_Attention_is_All_You_Need_Inference - - L2_NMT_Attention_is_All_You_Need_Finetuning - - L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation - - L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation - - L2_Megatron_NMT_Training_TP2 - - L2_TTS_Fast_dev_runs_1_Tacotron_2 - - L2_TTS_Fast_dev_runs_1_WaveGlow - - L2_TTS_Fast_dev_runs_1_FastPitch - #- OPTIONAL_L2_TTS_Fast_dev_runs_1_RADTTS - - L2_TTS_Fast_dev_runs_1_Hifigan - - Speech_Checkpoints_tests - - L2_Stable_Diffusion_Training - - L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2 - - L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2 - - L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2 - - L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2 - - L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2 - - L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2 - - L2_NeMo_2_NEVA_LOAD_GENERATE - - L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2 - - L2_NEMO_2_MLLAMA_Inference - - L2_NeMo_2_GPT_Pretraining_no_transformer_engine - - L2_NeMo_2_GPT_DDP_Param_Parity_check - - L2_NeMo_2_HF_MODEL_IMPORT - - L2_NeMo_2_llama3_pretraining_recipe - - L2_NeMo_2_llama3_fault_tolerance_plugin - - L2_NeMo_2_llama3_straggler_detection - - L2_HF_Transformer_PEFT_notebook - - L2_HF_Transformer_PEFT - - L2_HF_Transformer_PEFT_nemorun - - L2_HF_Transformer_PEFT_2gpu - - L2_HF_Transformer_PEFT_2gpu_FSDP2 - - L2_HF_Transformer_PEFT_2gpu_FSDP2_liger - - L2_HF_Transformer_PEFT_2gpu_nemorun - - L2_HF_Transformer_SFT_notebook - - L2_HF_Transformer_SFT - - L2_HF_Transformer_SFT_nemorun - - L2_HF_Transformer_SFT_2gpu - - L2_HF_Transformer_SFT_2gpu_FSDP2 - - L2_VLM_HF_Transformer_PEFT - - L2_VLM_HF_Transformer_PEFT_FSDP2 - - L2_VLM_HF_Transformer_PEFT_4bit - - L2_VLM_HF_Transformer_SFT_FSDP2 - - L2_HF_Transformer_SFT_2gpu_nemorun - - L2_HF_Transformer_SFT_TE_Acceleration - - L2_HF_Transformer_PT - - L2_HF_Transformer_PT_nemorun - - L2_HF_Transformer_PT_2gpu - - L2_HF_Transformer_PT_2gpu_nemorun - - L2_HF_Transformer_PT_TE_Acceleration - - L2_HF_Transformer_SpeechLM_SFT_2gpu - - L2_NeMo_2_SSM_Pretraining - - L2_NeMo_2_SSM_Finetuning - - L2_NeMo_2_Hyena_DDP_Pretraining_Test - - L2_NeMo_2_T5_Pretraining - - L2_NeMo_2_T5_Finetuning - - L2_NeMo_2_T5_LoRA - - L2_NeMo_2_GPT_SFT_TP1PP1_MBS1 - - L2_NeMo_2_GPT_SFT_TP1PP1_MBS2 - - L2_NeMo_2_GPT_SFT_TP1PP2_MBS2 - - L2_NeMo_2_GPT_SFT_TP2PP1_MBS2 - - L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1 - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2 - - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2 - - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2 - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat - - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED - - L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED - - L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED - - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2 - - L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1 - - L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1 - - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 - - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1 - - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude - - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude - - L2_NEMO_2_LoRA_MERGE - - L2_NEMO_2_LoRA_Export - - L2_NEMO_2_LoRA_Inference - - L2_NeMo_2_Mixtral_Pretraining - - L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124 - - L2_Speech_to_Text_AED - - L2_Speech_Estimate_Duration_Bins - - L2_Speech_Batch_Size_OOMptimizer - # - Optional_L2_Speech_Batch_Size_OOMptimizer_Canary - - L2_Speech_Transcription_Canary_Transcribe_Full_Manifest - - L2_Speech_Transcription_Canary_Transcribe_With_Prompt - - L2_Speech_Transcription_Canary_Transcribe_Audio_Dir - - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact - - L2_NeMo_2_PTQ_Llama2_FP8_trtllm - - L2_NeMo_2_PTQ_Llama2_FP8_nemo - - L2_NeMo_2_Distill_Llama3_TP1PP2 - - L2_NeMo_2_Prune_Llama_TP1PP2 - - 
L2_NeMo_2_Export_In_Framework - - L2_NeMo_2_jit_callback - - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING - - L2_HF_Transformer_SFT_FSDP2_2gpu - - L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2 - - L2_NeMo_2_VLLM_EXPORT - - L2_NeMo_2_EVAL - - L2_SpeechLM_LoRA_TP1PP1_MBS2 + - L2_TTS_Fast_dev_runs_Magpietts_config1 + + # - ASR_dev_run_Speech_to_Text + # - ASR_dev_run_Speech_to_Text_WPE_CitriNet + # - ASR_dev_run_Speech_Pre-training_-_CitriNet + # - Optional_ASR_dev_run_Speech_To_Text_Finetuning + # - Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning + # - ASR_dev_run_Speech_to_Text_WPE_-_Conformer + # - ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer + # - L2_Speech_to_Text_EMA + # - L2_Speaker_dev_run_Speaker_Recognition + # - L2_Speaker_dev_run_Speaker_Diarization + # - L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer + # - L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference + # - L2_Speaker_dev_run_Speech_to_Label + # - L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference + # - L2_Speaker_dev_run_Clustering_Diarizer_Inference + # - L2_Speaker_dev_run_Neural_Diarizer_Inference + # - L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation + # - L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader + # - L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader + # - L2_ASR_Adapters_Linear_Adapters + # - L2_ASR_Adapters_RelPos_MHA_Adapters + # - L2_Speech_Transcription_Speech_to_Text_Transcribe + # - L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav + # - L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3 + # - L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference + # - L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference + # - L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Post-LN + # - L2_NMT_Attention_is_All_You_Need_Training_NMT_Training_Pre-LN + # - L2_NMT_Attention_is_All_You_Need_Training_NMT_Multi-Validation + # - L2_NMT_Attention_is_All_You_Need_Inference + # - L2_NMT_Attention_is_All_You_Need_Finetuning + # - L2_NMT_Tarred_Dataset_Creation_Auto_Tarred_Dataset_Creation + # - L2_NMT_Tarred_Dataset_Creation_Script_Tarred_Dataset_Creation + # - L2_Megatron_NMT_Training_TP2 + # - L2_TTS_Fast_dev_runs_1_Tacotron_2 + # - L2_TTS_Fast_dev_runs_1_WaveGlow + # - L2_TTS_Fast_dev_runs_1_FastPitch + # #- OPTIONAL_L2_TTS_Fast_dev_runs_1_RADTTS + # - L2_TTS_Fast_dev_runs_1_Hifigan + # - Speech_Checkpoints_tests + # - L2_Stable_Diffusion_Training + # - L2_NeMo_2_NEVA_MOCK_PRETRAIN_TP2 + # - L2_NeMo_2_NEVA_MOCK_PRETRAIN_PP2 + # - L2_NeMo_2_NEVA_MOCK_PRETRAIN_CP2 + # - L2_NeMo_2_NEVA_MOCK_FINETUNE_TP2 + # - L2_NeMo_2_NEVA_MOCK_FINETUNE_PP2 + # - L2_NeMo_2_NEVA_MOCK_FINETUNE_CP2 + # - L2_NeMo_2_NEVA_LOAD_GENERATE + # - L2_NeMo_2_MLLAMA_MOCK_FINETUNE_TP2 + # - L2_NEMO_2_MLLAMA_Inference + # - L2_NeMo_2_GPT_Pretraining_no_transformer_engine + # - L2_NeMo_2_GPT_DDP_Param_Parity_check + # - L2_NeMo_2_HF_MODEL_IMPORT + # - L2_NeMo_2_llama3_pretraining_recipe + # - L2_NeMo_2_llama3_fault_tolerance_plugin + # - L2_NeMo_2_llama3_straggler_detection + # - L2_HF_Transformer_PEFT_notebook + # - L2_HF_Transformer_PEFT + # - L2_HF_Transformer_PEFT_nemorun + # - L2_HF_Transformer_PEFT_2gpu + # - L2_HF_Transformer_PEFT_2gpu_FSDP2 + # - L2_HF_Transformer_PEFT_2gpu_FSDP2_liger + # - L2_HF_Transformer_PEFT_2gpu_nemorun + # - L2_HF_Transformer_SFT_notebook + # - L2_HF_Transformer_SFT + # - L2_HF_Transformer_SFT_nemorun + # - L2_HF_Transformer_SFT_2gpu + # - L2_HF_Transformer_SFT_2gpu_FSDP2 + # - 
L2_VLM_HF_Transformer_PEFT + # - L2_VLM_HF_Transformer_PEFT_FSDP2 + # - L2_VLM_HF_Transformer_PEFT_4bit + # - L2_VLM_HF_Transformer_SFT_FSDP2 + # - L2_HF_Transformer_SFT_2gpu_nemorun + # - L2_HF_Transformer_SFT_TE_Acceleration + # - L2_HF_Transformer_PT + # - L2_HF_Transformer_PT_nemorun + # - L2_HF_Transformer_PT_2gpu + # - L2_HF_Transformer_PT_2gpu_nemorun + # - L2_HF_Transformer_PT_TE_Acceleration + # - L2_HF_Transformer_SpeechLM_SFT_2gpu + # - L2_NeMo_2_SSM_Pretraining + # - L2_NeMo_2_SSM_Finetuning + # - L2_NeMo_2_Hyena_DDP_Pretraining_Test + # - L2_NeMo_2_T5_Pretraining + # - L2_NeMo_2_T5_Finetuning + # - L2_NeMo_2_T5_LoRA + # - L2_NeMo_2_GPT_SFT_TP1PP1_MBS1 + # - L2_NeMo_2_GPT_SFT_TP1PP1_MBS2 + # - L2_NeMo_2_GPT_SFT_TP1PP2_MBS2 + # - L2_NeMo_2_GPT_SFT_TP2PP1_MBS2 + # - L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED + # - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1 + # - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2 + # - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2 + # - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2 + # - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat + # - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED + # - L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED + # - L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED + # - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2 + # - L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1 + # - L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1 + # - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 + # - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1 + # - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1_exclude + # - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2_exclude + # - L2_NEMO_2_LoRA_MERGE + # - L2_NEMO_2_LoRA_Export + # - L2_NEMO_2_LoRA_Inference + # - L2_NeMo_2_Mixtral_Pretraining + # - L2_NeMo_2_Auto_Configurator_TP1_PP1_MBS124 + # - L2_Speech_to_Text_AED + # - L2_Speech_Estimate_Duration_Bins + # - L2_Speech_Batch_Size_OOMptimizer + # # - Optional_L2_Speech_Batch_Size_OOMptimizer_Canary + # - L2_Speech_Transcription_Canary_Transcribe_Full_Manifest + # - L2_Speech_Transcription_Canary_Transcribe_With_Prompt + # - L2_Speech_Transcription_Canary_Transcribe_Audio_Dir + # - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact + # - L2_NeMo_2_PTQ_Llama2_FP8_trtllm + # - L2_NeMo_2_PTQ_Llama2_FP8_nemo + # - L2_NeMo_2_Distill_Llama3_TP1PP2 + # - L2_NeMo_2_Prune_Llama_TP1PP2 + # - L2_NeMo_2_Export_In_Framework + # - L2_NeMo_2_jit_callback + # - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING + # - L2_HF_Transformer_SFT_FSDP2_2gpu + # - L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2 + # - L2_NeMo_2_VLLM_EXPORT + # - L2_NeMo_2_EVAL + # - L2_SpeechLM_LoRA_TP1PP1_MBS2 if: always() && github.event != 'push' runs-on: ubuntu-latest @@ -2039,7 +2051,7 @@ jobs: JOBS='[]' PAGE=1 while : ; do - JOBS_URL="https://api.github.com/repos/$REPOSITORY/actions/runs/$RUN_ID/jobs?page=$PAGE&per_page=100" + JOBS_URL="https://api.github.com/repos/$REPOSITORY/actions/runs/$RUN_ID/jobs?page=$PAGE&per_page=100" RESPONSE=$(curl -s -H "Authorization: token $GITHUB_TOKEN" $JOBS_URL | jq '.jobs') JOBS=$(echo -e "$JOBS\n$RESPONSE" | jq -cs 'add') if [[ $(echo $RESPONSE | jq 'length') -lt 100 ]]; then @@ -2061,7 +2073,7 @@ jobs: LOGS=$(echo $JOB | yq '(.value.outputs.log | @base64d)' | tr -d '"') LOGS=$([[ $(echo $LOGS | wc -c) -gt 0 ]] && echo -E "\`\`\`\n$LOGS\n\`\`\`" || echo "") LOGS=$([[ $(echo $JOB | yq '.value.outputs.potential_infra_failure') == "true" ]] && echo -E "$LOGS\n\ncc: $SLACK_WEBHOOK_ADMIN" || echo -E "$LOGS") - + SUMMARY=$(echo "$SUMMARY" | jq \ --arg pr "<$PR_URL|$PR_TITLE>" \ --arg job "<$JOB_URL|$JOB_NAME>" \ diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index 
a7158e6393ff..f778da20f8d0 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy import json import os @@ -31,7 +44,7 @@ class MagpieTTSModelOfflinePODataGen(MagpieTTSModel): This class is used in 'test' mode and leverages trainer.test() for multi-GPU/multi-node inference. Saves the predicted audio files and logs the CER/WER metrics as individual json files for each audio. """ - + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) if cfg.get('pref_set_language', "en") == "en": @@ -143,8 +156,8 @@ def test_step(self, batch, batch_idx): class MagpieTTSModelOfflinePO(MagpieTTSModel): """ - MagpieTTS_Model_OfflinePO is a class that extends MagpieTTS_Model to support - offline preference optimization (DPO, IPO, RPO). + MagpieTTS_Model_OfflinePO is a class that extends MagpieTTS_Model to support + offline preference optimization (DPO, IPO, RPO). Set cfg.model.dpo_loss_type to 'dpo', 'ipo', or 'rpo' to use the corresponding loss. """ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): diff --git a/scripts/magpietts/codec_extraction.py b/scripts/magpietts/codec_extraction.py index 83c164a9a6c8..0266bc5e7039 100644 --- a/scripts/magpietts/codec_extraction.py +++ b/scripts/magpietts/codec_extraction.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import torch from torch.utils.data import Dataset, DataLoader diff --git a/scripts/magpietts/eval_squimmos.py b/scripts/magpietts/eval_squimmos.py index 8c40994ed309..047ee0b743b0 100644 --- a/scripts/magpietts/eval_squimmos.py +++ b/scripts/magpietts/eval_squimmos.py @@ -1,4 +1,16 @@ -from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE import os import json import torch @@ -22,7 +34,7 @@ def compute_mean_and_confidence_interval(measurements, confidence=0.95): std_err = stats.sem(measurements) confidence_interval = std_err * stats.t.ppf((1 + confidence) / 2, len(measurements) - 1) - + return "{:.4f} +/- {:.4f}".format(mean, confidence_interval), mean, confidence_interval def main(): @@ -54,7 +66,7 @@ def main(): with torch.no_grad(): squm_mos_score = squim_mos_model(pred_wav, gt_wav) squim_score_list.append(squm_mos_score.item()) - + mean_with_ci, mean, confidence_interval = compute_mean_and_confidence_interval(squim_score_list) # Add to audio_dir,mean_with_ci to csv with open(out_file, "a") as f: diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 8436df0538fe..05b9fd66da4e 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. dataset_meta_info = { 'vctk': { 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json', @@ -29,12 +42,12 @@ 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_mid.json', 'audio_dir' : '/datap/misc/Datasets/LibriTTS', 'feature_dir' : '/datap/misc/Datasets/LibriTTS', - }, + }, 'libri_dev_clean_eval_tiny': { 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_tiny.json', 'audio_dir' : '/datap/misc/Datasets/LibriTTS', 'feature_dir' : '/datap/misc/Datasets/LibriTTS', - }, + }, 'libri_val': { 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 74d862ec685e..f70860d904d5 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import json import os diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 69a255c18c15..86ac0f52525c 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import copy import glob @@ -411,7 +424,7 @@ def main(): start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, - maskgit_n_steps=args.maskgit_n_steps, + maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks ) else: diff --git a/scripts/tts_dataset_to_lhotse/create_shars.py b/scripts/tts_dataset_to_lhotse/create_shars.py index ad7663e86fc7..f680ba80c877 100644 --- a/scripts/tts_dataset_to_lhotse/create_shars.py +++ b/scripts/tts_dataset_to_lhotse/create_shars.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from pathlib import Path import json import os diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py index e1429df4fb70..325fde5b3bf7 100644 --- a/tests/collections/tts/modules/test_audio_codec_modules.py +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -240,6 +240,7 @@ def test_rvq_eval(self, num_codebooks: int): torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch') torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch') + @pytest.mark.pleasefixme @pytest.mark.unit @pytest.mark.parametrize('num_groups', [1, 2, 4]) @pytest.mark.parametrize('num_codebooks', [1, 4]) diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh new file mode 100644 index 000000000000..a6aaa2013350 --- /dev/null +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh @@ -0,0 +1,33 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ + --config-name magpietts_dc_en \ + +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train.json" \ + +train_ds_meta.an4.audio_dir="/" \ + +train_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +train_ds_meta.an4.feature_dir=null \ + +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val.json" \ + +val_ds_meta.an4.audio_dir="/" \ + +val_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +val_ds_meta.an4.feature_dir=null \ + max_epochs=1 \ + batch_size=4 \ + model.codecmodel_path="/home/TestData/tts/21fps_causal_codecmodel.nemo" \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.strategy=auto \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.num_workers=0 \ + ~trainer.check_val_every_n_epoch From 00be498f269281ebd58cbd9ab93426e7034d7b47 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 28 May 2025 06:30:29 -0700 Subject: [PATCH 038/113] BPE char tokenizer (#13594) * BPE char tokenizer in progress Signed-off-by: Shehzeen Hussain * added transformer module to CharAwareSubwordEncoder Signed-off-by: Shehzeen Hussain * limit max length in char transformer Signed-off-by: Shehzeen Hussain * handle special tokens correctly in BPE tokenizer. Refactor code and remove unused imports Signed-off-by: Shehzeen Hussain * correctly handle pad id. 
added docstrings Signed-off-by: Shehzeen Hussain * Update nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py Co-authored-by: Jason Signed-off-by: Shehzeen Hussain --------- Signed-off-by: Shehzeen Hussain Co-authored-by: Jason --- .../text_to_speech/tts_tokenizers.py | 9 +- nemo/collections/tts/models/magpietts.py | 49 +++++-- .../tts/modules/magpietts_modules.py | 131 +++++++++++++++++- 3 files changed, 174 insertions(+), 15 deletions(-) diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index c5509d3ffa30..3d25289a2007 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -1118,7 +1118,14 @@ def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase self.tokens = tokens self.tokenizer_names = tokenizer_names self.toknizer_offsets = toknizer_offsets - self.pad = self.tokenizers[tokenizer_names[0]].pad # Use the first tokenizer's pad token + # Define aggregated token's pad value from the first tokenizer's pad value + first_tokenizer = self.tokenizers[tokenizer_names[0]] + if hasattr(first_tokenizer, "pad_token_id"): # Defined in PreTrainedTokenizerBase subclasses + self.pad = first_tokenizer.pad_token_id + elif hasattr(first_tokenizer, "pad"): # Defined in BaseTokenizer subclasses + self.pad = first_tokenizer.pad + else: + raise ValueError("AggregatedTTSTokenizer could not find a padding token in the first tokenizer") def encode(self, text: str, tokenizer_name: str) -> List[int]: tokenizer = self.tokenizers[tokenizer_name] diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index e155f4351032..18ba9d69223b 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -11,15 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
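The pad-id handling introduced above reduces to a small duck-typing branch between HF and NeMo tokenizer interfaces; a minimal sketch of that logic, where the helper name resolve_pad_id is hypothetical and used only for illustration:

def resolve_pad_id(tokenizer):
    # HF PreTrainedTokenizerBase subclasses expose `pad_token_id`.
    if hasattr(tokenizer, "pad_token_id"):
        return tokenizer.pad_token_id
    # NeMo BaseTokenizer subclasses expose `pad`.
    if hasattr(tokenizer, "pad"):
        return tokenizer.pad
    raise ValueError("Tokenizer does not define a padding token")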
-import json import os import random -import string import time from typing import List - -import librosa -import numpy as np import soundfile as sf import torch import wandb @@ -31,25 +26,22 @@ from torch.utils.data import get_worker_info import nemo.collections.asr as nemo_asr -from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.modules.aligner import AlignmentEncoder -from nemo.collections.tts.modules.magpietts_modules import SpecialAudioToken, LocalTransformerType, cosine_schedule +from nemo.collections.tts.modules.magpietts_modules import CharAwareSubwordEncoder, SpecialAudioToken, LocalTransformerType, cosine_schedule from nemo.collections.tts.parts.utils.helpers import ( binarize_attention_parallel, get_mask_from_lengths, plot_alignment_to_numpy, ) -from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging - def worker_init_fn(worker_id): # For mp.set_start_method("spawn", force=True) # The dataset class should be picklable, so we initialize non-picklable objects here @@ -61,7 +53,7 @@ def worker_init_fn(worker_id): ) dataset.text_tokenizer = tokenizer dataset.text_conditioning_tokenizer = text_conditioning_tokenizer - + class MagpieTTSModel(ModelPT): """ Magpie-TTS Model Base Class used for training a TTS model that can generate audio codes from transcript and a context @@ -109,7 +101,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.context_audio_eos_id = cfg.get('forced_context_audio_eos_id', num_audio_tokens + SpecialAudioToken.AUDIO_CONTEXT_EOS.value) self.num_all_tokens_per_codebook = cfg.get('forced_num_all_tokens_per_codebook',num_audio_tokens + len(SpecialAudioToken)) self.mask_token_id = cfg.get('forced_mask_token_id', num_audio_tokens + SpecialAudioToken.MASK_TOKEN.value) - + self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) + # Setup tokenizer if hasattr(cfg, 'text_tokenizer'): # For backward compatibility for English-only models @@ -153,7 +146,29 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if self.model_type != 'decoder_pretrain_synthesizer': # Decoder pretrain synthesizer doesn't have transcript encoder/text embeddings - self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim) + + if self.use_bpe_char_tokenizer: + # BPE char tokenizer + assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer" + tokenizer_name = self.tokenizer.tokenizer_names[0] + tokenizer = self.tokenizer.tokenizers[tokenizer_name] + subword_vocab = tokenizer.get_vocab() + # special tokens will be stored as it is in the char_vocab + # Each special token will only be mapped to one char id + special_vocab = { + '': self.bos_id, + '': self.eos_id, + } + self.cas_encoder = CharAwareSubwordEncoder( + d_embed=cfg.embedding_dim, + llm_tokenizer_vocab=subword_vocab, + subword_padding_idx=self.tokenizer.pad, + special_vocab=special_vocab + ) + else: + # Regular text embedding + self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim) + 
self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) @@ -796,6 +811,14 @@ def scale_prior(self, prior, global_step): / (prior_end_step - prior_scaledown_start_step) ) return new_prior + + def embed_text(self, text, text_mask): + if self.use_bpe_char_tokenizer: + text_embedded = self.cas_encoder(text, subword_mask=text_mask) + else: + text_embedded = self.text_embedding(text) + + return text_embedded def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_context_size=0): # attention scores: List of (B, C, audio_timesteps, text_timesteps) @@ -830,8 +853,8 @@ def prepare_context_tensors(self, batch): if self.model_type != 'decoder_pretrain_synthesizer': text = batch['text'] text_lens = batch['text_lens'] - text_embedded = self.text_embedding(text) # (B, T, E) text_mask = get_mask_from_lengths(text_lens) # (B, T) + text_embedded = self.embed_text(text, text_mask) # (B, T, E) text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E) _attn_prior = batch.get('align_prior_matrix', None) _attn_prior = self.scale_prior(_attn_prior, self.global_step) diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 72fc063cce65..1512940e5fbb 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -14,7 +14,10 @@ from enum import Enum, StrEnum, verify, CONTINUOUS, UNIQUE import torch - +from nemo.collections.tts.modules import transformer_2501 +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from torch import Tensor +from nemo.core.classes.module import NeuralModule class LocalTransformerType(StrEnum): """ @@ -52,3 +55,129 @@ def cosine_schedule(x: torch.Tensor): Used for MaskGit mask scheduling. """ return torch.cos(x * (torch.pi / 2)) + +def build_vocabs(subword_vocab: dict, subword_padding_idx: int, special_vocab: dict = None) -> tuple[dict, dict]: + """ + Builds the character vocabulary and the mapping from subword ids to character ids. + Args: + subword_vocab (dict): A dictionary of subword vocab items. Eg. + tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name) + subword_vocab = tokenizer.vocab + subword_padding_idx (int): The padding index for the subword vocabulary. + special_vocab (dict): items of special token dictionary (usually BOS, EOS) + eg. special_vocab = {'': 0, '': 1} + Returns: + subword_id_to_char_ids: A dictionary mapping subword ids to character ids. + char_vocab: A dictionary mapping character ids to their corresponding characters. 
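A toy walk-through of the mapping described above (the four-entry subword vocabulary and the '<bos>'/'<eos>' strings are invented purely for illustration):

# Hypothetical inputs:
#   subword_vocab       = {'a': 0, 'b': 1, 'ab': 2, '_': 3}
#   subword_padding_idx = 3
#   special_vocab       = {'<bos>': 4, '<eos>': 5}
# build_vocabs(...) would then return:
#   subword_id_to_char_ids = {0: (0,), 1: (1,), 2: (0, 1),   # regular subwords spelled out as char ids
#                             4: (3,), 5: (4,),              # special tokens map to a single char id each
#                             3: (5,)}                       # pad subword -> len(char_vocab), the embedding's padding row
#   char_vocab             = {'a': 0, 'b': 1, '_': 2, '<bos>': 3, '<eos>': 4}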
+ """ + org_char_vocab = {subword: subword_id for subword, subword_id in subword_vocab.items() if len(subword) == 1} + + # Add special tokens directly to char vocab + if special_vocab is not None: + for special_token, special_token_id in special_vocab.items(): + if special_token in org_char_vocab: + raise ValueError(f"Special token {special_token} already exists in the character vocabulary.") + org_char_vocab[special_token] = special_token_id + + sorted_char_vocab = dict(sorted(org_char_vocab.items(), key=lambda x: x[1])) + char_vocab = {k: i for i, (k, _) in enumerate(sorted_char_vocab.items())} + assert sorted(char_vocab.values()) == list(range(len(char_vocab))) + subword_id_to_char_ids = { + subword_id: tuple(char_vocab[char] for char in subword) for subword, subword_id in subword_vocab.items() + } + + # Creating mapping from subword ids of special tokens to their char ids + if special_vocab is not None: + for special_token, special_token_id in special_vocab.items(): + if special_token in subword_id_to_char_ids: + raise ValueError(f"Special token {special_token} already exists in the subword id Vocabulary.") + subword_id_to_char_ids[special_token_id] = (char_vocab[special_token],) + + assert max(subword_id_to_char_ids) == len(subword_id_to_char_ids) - 1 + + # Always add padding token to the end of the vocab (this is the convention used in the original code) + subword_id_to_char_ids[subword_padding_idx] = (len(char_vocab),) + + return subword_id_to_char_ids, char_vocab + +class CharAwareSubwordEncoder(NeuralModule): + """ + Char-aware subword encoder for the MagpieTTS model. + This module takes subword ids as input, maps them to character ids, and then applies a transformer encoder to the character embeddings. + The output is a tensor of shape (batch_size, max_subword_length, d_embed). + """ + def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: int, special_vocab: dict = None): + """ + Args: + d_embed (int): The dimension of the embedding. + llm_tokenizer_vocab (dict): A dictionary of subword vocab items. Eg. + tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name) + llm_tokenizer_vocab = tokenizer.vocab + subword_padding_idx (int): The padding index for the subword vocabulary. + special_vocab (dict): items of special token dictionary (usually BOS, EOS) + eg. 
special_vocab = {'': 30001, '': 30002} + """ + super().__init__() + self.subword_id_to_char_ids, self.char_vocab = build_vocabs(llm_tokenizer_vocab, subword_padding_idx, special_vocab) + self.embed_tokens = torch.nn.Embedding(self.vocab_size+1, d_embed, padding_idx=self.vocab_size) + self.encoder = transformer_2501.Transformer( + n_layers=1, + d_model=d_embed, + d_ffn=d_embed * 4, + sa_n_heads=8, + kernel_size=1, + max_length_causal_mask=256 + ) + + @property + def vocab_size(self): + return len(self.char_vocab) + + def prepare_inputs(self, subword_ids: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]: + device = subword_ids.device + + subword_id_list = torch.masked_select(subword_ids, padding_mask).cpu().tolist() + char_id_list = [list(self.subword_id_to_char_ids[x]) for x in subword_id_list] + + char_lengths = torch.tensor([len(x) for x in char_id_list], dtype=torch.long, device=device) + batch_size = char_lengths.size(0) + + char_ids = torch.full((batch_size, int(char_lengths.max().item())), self.vocab_size, dtype=torch.long) + for i in range(batch_size): + char_ids[i, : char_lengths[i]] = torch.tensor(char_id_list[i]) + char_ids = char_ids.to(device=device) + return char_ids, char_lengths + + def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Tensor: + """ + Args: + subword_ids (Tensor): A tensor of shape (batch_size, max_subword_length) containing the subword ids. + subword_mask (Tensor | None): A tensor of shape (batch_size, max_subword_length) containing the mask for the subword ids. + If None, a mask of ones will be used. + Returns: + Tensor: A tensor of shape (batch_size, max_subword_length, d_embed) containing the subword embeddings. + """ + device = subword_ids.device + if subword_mask is None: + subword_mask = torch.ones_like(subword_ids).bool() + else: + subword_mask = subword_mask.bool() + + if subword_mask.ndim == 3: + subword_mask = subword_mask.squeeze(-1) + + char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) + char_mask = get_mask_from_lengths(char_lengths) + char_emb = self.embed_tokens(char_ids) + # char emb has the shape [B*T, N, channels], where N is the max number of chars tokens decoded from bpe tokens + x = self.encoder( + x=char_emb, + x_mask=char_mask + )['output'] + + # Get average embedding over the chars + mean_emb = ((x / char_mask.unsqueeze(-1).sum(1, keepdim=True)) * char_mask.unsqueeze(-1)).sum(1) + subword_emb = torch.zeros((subword_mask.size(0), subword_mask.size(1), mean_emb.size(-1)), device=device) + subword_emb[subword_mask.unsqueeze(-1).expand(-1, -1, mean_emb.size(-1))] = mean_emb.view(-1) + + return subword_emb \ No newline at end of file From a6c73f92cb70d863d48d2ee42e2e92af0c6b0a71 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Wed, 28 May 2025 09:37:21 -0700 Subject: [PATCH 039/113] Python 3.10 compatibility (#13753) - Remove dependency on StrEnum (requires Python >= 3.11) and replace it with NeMo's PrettyStrEnum - Remove use of @verify decorator from Enum class (requires Python >= 3.11) Signed-off-by: Fejgin, Roy Co-authored-by: Jason --- nemo/collections/tts/modules/magpietts_modules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 1512940e5fbb..9751b51508f7 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -12,14 +12,15 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from enum import Enum, StrEnum, verify, CONTINUOUS, UNIQUE +from enum import Enum +from nemo.utils.enum import PrettyStrEnum import torch from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths from torch import Tensor from nemo.core.classes.module import NeuralModule -class LocalTransformerType(StrEnum): +class LocalTransformerType(PrettyStrEnum): """ Enum for the type of local transformer to use in the MagpieTTS model. These strings are the values allowed in the YAML config file. @@ -30,7 +31,6 @@ class LocalTransformerType(StrEnum): MASKGIT = "maskgit" -@verify(CONTINUOUS, UNIQUE) class SpecialAudioToken(Enum): """ Enum for the special tokens to use in the MagpieTTS model. From de1fe3993348942fce94b0204c5838e05fe4ecb2 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Fri, 30 May 2025 10:04:49 -0700 Subject: [PATCH 040/113] =?UTF-8?q?Add=20Fr=C3=A9chet=20codec=20distance?= =?UTF-8?q?=20metric=20(#13553)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Squashed commit of the following: commit cc615075d7286fa7274cb87d61989471721928c6 Merge: d0a5ad7a69 a704b9f3c0 Author: Fejgin, Roy Date: Mon May 12 11:15:33 2025 -0700 Merge branch 'magpietts_2503_frechet' of github.com:rfejgin/NeMo into magpietts_2503_frechet commit d0a5ad7a69669641de2cfc2116bbaeac9abf8973 Author: Fejgin, Roy Date: Mon May 12 11:15:15 2025 -0700 Comments and cleanup Signed-off-by: Fejgin, Roy commit a704b9f3c070fed931af7c2adf13bd14ebc21fba Merge: 9e9f1fa3a3 38234592b6 Author: Fejgin, Roy Date: Wed May 7 16:45:56 2025 -0700 Merge remote-tracking branch 'blisc/magpietts_2503' into magpietts_2503_frechet commit 9e9f1fa3a381a88e37768b9495868667759b6686 Author: Fejgin, Roy Date: Tue May 6 11:18:56 2025 -0700 Simplify how we calculate the codec feature dimension. Turns out there is an API for this for both FSQ and RVQ quantizers -- so we use it here. commit 416135db934c633b2a0c25196ed091d1f62c31ac Author: Fejgin, Roy Date: Wed Apr 30 22:20:26 2025 -0700 Comments commit 0af52e7f6c6ed6e581eb6988260c48571323ce86 Author: Fejgin, Roy Date: Wed Apr 30 21:31:48 2025 -0700 Comments commit 7aa2f5f5094750226ed04b5ba3c4a2c0155966b9 Author: Fejgin, Roy Date: Wed Apr 30 21:13:24 2025 -0700 Cleanup an minor tweaks commit ca44c16cacde73556428bea406743fd61c6cad5d Author: Fejgin, Roy Date: Wed Apr 30 16:15:50 2025 -0700 Integrate FCD into infer_and_evaluate commit 73759fc54084dc1b8542657890ab78089b89175e Author: Fejgin, Roy Date: Wed Apr 30 15:18:32 2025 -0700 Add Frechet Codec Distance metric implementation Signed-off-by: Fejgin, Roy * Addres PR comments and some additional clenaup * Fix typos * Create tensors direcly on device * Rename update_from_codes() to update() * Use AudioSegment() directly instead of _read_audio() * Added some PyTest tests for the FCD metric (fairly basic for now) * Use NeMo warning logging rather than the built-in Python one Signed-off-by: Fejgin, Roy * Cleanup Signed-off-by: Fejgin, Roy * Cleanup and optimize tensor creation Signed-off-by: Fejgin, Roy * FCD test: use more broadly-supported API codec emedding dimension Signed-off-by: Fejgin, Roy * Adding clarifying comments Signed-off-by: Fejgin, Roy * Error handling and copyright notices Signed-off-by: Fejgin, Roy * Fix typo * FCD: address some PR comments Comment; change WAV file reading; touch-ups. 
Signed-off-by: Fejgin, Roy * Fix integration of FCD into the evaluation scripts We now write out codec codes to file and then read them back during evaluate_generated_audio, similar to what we do for the wav files. Since these files aren't likely to be of interest after evaluation, we automatically delete them once done. Signed-off-by: Fejgin, Roy * Remove an assert as per PR comments Signed-off-by: Fejgin, Roy * Add a warning if FCD is negative Signed-off-by: Fejgin, Roy * Address PR comments Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- nemo/collections/tts/modules/fcd_metric.py | 271 ++++++++++++++++++ scripts/magpietts/evaluate_generated_audio.py | 59 +++- scripts/magpietts/infer_and_evaluate.py | 27 +- .../tts/modules/test_fcd_metric.py | 103 +++++++ 4 files changed, 441 insertions(+), 19 deletions(-) create mode 100644 nemo/collections/tts/modules/fcd_metric.py create mode 100644 tests/collections/tts/modules/test_fcd_metric.py diff --git a/nemo/collections/tts/modules/fcd_metric.py b/nemo/collections/tts/modules/fcd_metric.py new file mode 100644 index 000000000000..5d703c6ddfbd --- /dev/null +++ b/nemo/collections/tts/modules/fcd_metric.py @@ -0,0 +1,271 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This is an experimental metric. It measures the Frechet Distance between distributions of generated and real +codec frames. The distance is measured in the embedding space of the codec. We get the embeddings +by dequantizing codec frames. + +Like all FD metrics, the metric operates on a dataset level. A large number of real and generated frames are needed for the metric to be reliable -- on the order of tens of thousands. + +The frames are currently considered independently, i.e. temporal relationships between are not captured (though this might +be useful to explore). +""" + +import warnings + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torchmetrics import Metric +import numpy as np +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.utils import logging + + +class CodecEmbedder(nn.Module): + """ + Embeds audio codec codes into the codec's continuous embedding space. + Accepts as input either a batch of codes or a path to an audio file. + """ + def __init__(self, codec: AudioCodecModel): + super().__init__() + self.codec = codec + + def codes_to_embedding(self, x: Tensor, x_len: Tensor) -> Tensor: + """ + Embeds a batch of audio codec codes into the codec's continuous embedding space. + """ + # x: (B, C, T + # x_len: (B,) + return self.codec.dequantize(tokens=x, tokens_len=x_len) + + def encode_from_file(self, audio_path: str) -> Tensor: + """ + Encodes an audio file into audio codec codes. 
+ """ + audio_segment = AudioSegment.from_file( + audio_path, target_sr=self.codec.sample_rate) + assert np.issubdtype(audio_segment.samples.dtype, np.floating) + audio_min = audio_segment.samples.min() + audio_max = audio_segment.samples.max() + eps = 0.01 # certain ways of normalizing audio can result in samples that are slightly outside of [-1, 1] + if audio_min < (-1.0 - eps) or audio_max > (1.0 + eps): + logging.warning(f"Audio samples are not normalized: min={audio_min}, max={audio_max}") + samples = torch.tensor(audio_segment.samples, device=self.codec.device).unsqueeze(0) + audio_len = torch.tensor(samples.shape[1], device=self.codec.device).unsqueeze(0) + codes, codes_len = self.codec.encode(audio=samples, audio_len=audio_len) + return codes, codes_len + + +class FrechetCodecDistance(Metric): + """ + Computes the Frechet Codec Distance between two distributions of audio codec frames (real and generated). + This is done in codec embedding space, one frame at a time. We name this metric the Frechet Codec Distance (FCD). + """ + + """ + Parts of this are based on the following implementation of FID (Frechet Inception Distance) on images: + + https://github.com/pytorch/torcheval/blob/main/torcheval/metrics/image/fid.py + + # Copyright (c) Meta Platforms, Inc. and affiliates. + # All rights reserved. + # + # This source code is licensed under the BSD-style license found in the + # LICENSE file in the root directory of this source tree. + + Contents of original LICENSE file: + + # BSD License + # + # For torcheval software + # + # Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + # + # Redistribution and use in source and binary forms, with or without modification, + # are permitted provided that the following conditions are met: + # + # * Redistributions of source code must retain the above copyright notice, this + # list of conditions and the following disclaimer. + # + # * Redistributions in binary form must reproduce the above copyright notice, + # this list of conditions and the following disclaimer in the documentation + # and/or other materials provided with the distribution. + # + # * Neither the name Meta nor the names of its contributors may be used to + # endorse or promote products derived from this software without specific + # prior written permission. + # + # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR + # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON + # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + """ + is_differentiable = False + higher_is_better = False + full_state_update = False + + def __init__( + self, + codec, + feature_dim: int, + ) -> None: + """ + Computes the Frechet Codec Distance between two distributions of audio codec codes (real and generated). + The original paper (FID on images): https://arxiv.org/pdf/1706.08500.pdf + + Args: + codec (AudioCodecModel): The codec model to use. 
+ feature_dim (int): The number of features in the codec embedding space (usually 4*num_codebooks) + """ + super().__init__() + + # Set the model and put it in evaluation mode + self.model = CodecEmbedder(codec) + self.model.eval() + self.model.requires_grad_(False) + + # Initialize state variables used to compute FCD + self.add_state("real_sum", default=torch.zeros(feature_dim), dist_reduce_fx="sum") + self.add_state("real_cov_sum", default=torch.zeros((feature_dim, feature_dim)), dist_reduce_fx="sum") + self.add_state("fake_sum", default=torch.zeros(feature_dim), dist_reduce_fx="sum") + self.add_state("fake_cov_sum", default=torch.zeros((feature_dim, feature_dim)), dist_reduce_fx="sum") + self.add_state("num_real_frames", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") + self.add_state("num_fake_frames", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") + + def update_from_audio_file(self, audio_path: str, is_real: bool) -> Tensor: + """ + Takes a path to an audio file, embeds it, and updates the FCD metric. + """ + codes, codes_len = self.model.encode_from_file(audio_path) + self.update(codes, codes_len, is_real) + + def update(self, codes: Tensor, codes_len: Tensor, is_real: bool): + """ + Update the states with a batch of real or fake codes. + Takes pre-computed codec codes, embeds them, and updates the FCD metric. + + Args: + codes (Tensor): A batch of codec frames of shape (B, C, T). + codes_len (Tensor): A batch of lengths of the codec frames of shape (B,). + is_real (Boolean): Denotes if samples are real or not. + """ + assert codes.ndim == 3 + + # Dequantize the codes to a continuous representation + embeddings = self.model.codes_to_embedding( + codes, codes_len + ) # B, E, T where E is the codec's embedding dimension, usually 4*num_codebooks + + # keep only the valid frames + valid_frames = [] + for i in range(codes.shape[0]): + valid_frames.append(embeddings[i, :, : codes_len[i]].T) # T', E + embeddings = torch.cat(valid_frames, dim=0) # total_valid_frames, E + valid_frame_count = embeddings.shape[0] + + # Update the state variables used to compute FCD + if is_real: + self.num_real_frames += valid_frame_count + self.real_sum += torch.sum(embeddings, dim=0) + self.real_cov_sum += torch.matmul(embeddings.T, embeddings) + else: + self.num_fake_frames += valid_frame_count + self.fake_sum += torch.sum(embeddings, dim=0) + self.fake_cov_sum += torch.matmul(embeddings.T, embeddings) + + return self + + def compute(self) -> Tensor: + """ + Compute the FCD. + + Returns: + tensor: The FCD. + """ + + # If the user has not already updated with at lease one + # sample from each distribution, then we raise an Error. + if (self.num_real_frames == 0) or (self.num_fake_frames == 0): + logging.warning( + "Computing FD requires at least 1 real frame and 1 fake frame," + f"but currently running with {self.num_real_frames} real frames and {self.num_fake_frames} fake frames." 
+ "Returning 0.0" + ) + return torch.tensor(0.0, device=self.device) + + # Compute the mean activations for each distribution + real_mean = (self.real_sum / self.num_real_frames).unsqueeze(0) + fake_mean = (self.fake_sum / self.num_fake_frames).unsqueeze(0) + + # Compute the covariance matrices for each distribution + real_cov_num = self.real_cov_sum - self.num_real_frames * torch.matmul(real_mean.T, real_mean) + real_cov = real_cov_num / (self.num_real_frames - 1) + fake_cov_num = self.fake_cov_sum - self.num_fake_frames * torch.matmul(fake_mean.T, fake_mean) + fake_cov = fake_cov_num / (self.num_fake_frames - 1) + + # Compute the Frechet Distance between the distributions + fd = self.calculate_frechet_distance(real_mean.squeeze(), real_cov, fake_mean.squeeze(), fake_cov) + # FD should be non-negative but due to numerical errors, it can be slightly negative + # Have seen -0.0011 in the past + if fd < -0.005: + logging.warning(f"FCD is negative, which is unexpected: {fd}") + return torch.clamp(fd, min=0.0) + + def calculate_frechet_distance( + self, + mu1: Tensor, + sigma1: Tensor, + mu2: Tensor, + sigma2: Tensor, + ) -> Tensor: + """ + Calculate the Frechet Distance between two multivariate Gaussian distributions. + + Args: + mu1 (Tensor): The mean of the first distribution. Shape: (feature_dim,) + sigma1 (Tensor): The covariance matrix of the first distribution. Shape: (feature_dim, feature_dim) + mu2 (Tensor): The mean of the second distribution. Shape: (feature_dim,) + sigma2 (Tensor): The covariance matrix of the second distribution. Shape: (feature_dim, feature_dim) + + Returns: + tensor: The Frechet Distance between the two distributions. + """ + # Compute the squared distance between the means + mean_diff = mu1 - mu2 + mean_diff_squared = mean_diff.square().sum(dim=-1) + + # Calculate the sum of the traces of both covariance matrices + trace_sum = sigma1.trace() + sigma2.trace() + + # Compute the eigenvalues of the matrix product of the real and fake covariance matrices + sigma_mm = torch.matmul(sigma1, sigma2) + eigenvals = torch.linalg.eigvals(sigma_mm) + + # Take the square root of each eigenvalue and take its sum + sqrt_eigenvals_sum = eigenvals.sqrt().real.sum(dim=-1) + + # Calculate the FCD using the squared distance between the means, + # the sum of the traces of the covariance matrices, and the sum of the square roots of the eigenvalues + fcd = mean_diff_squared + trace_sum - 2 * sqrt_eigenvals_sum + + return fcd diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index f70860d904d5..720705bf4abc 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -21,21 +21,29 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate_detail +from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance +from nemo.collections.tts.models import AudioCodecModel from transformers import WhisperProcessor, WhisperForConditionalGeneration import librosa -import evalset_config +import scripts.magpietts.evalset_config as evalset_config from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector -def find_sample_audios(audio_dir): +def find_generated_files(audio_dir, prefix, extension): file_list = [] for f in os.listdir(audio_dir): - if "predicted_audio" in f and f.endswith(".wav"): - audio_number = int(f.split("_")[-1].split(".wav")[0]) + if prefix in f and f.endswith(extension): + audio_number = 
int(f.split("_")[-1].split(extension)[0]) file_list.append((audio_number, os.path.join(audio_dir, f))) file_list.sort() file_list = [t[1] for t in file_list] return file_list +def find_generated_audio_files(audio_dir): + return find_generated_files(audio_dir=audio_dir, prefix="predicted_audio", extension=".wav") + +def find_generated_codec_files(audio_dir): + return find_generated_files(audio_dir=audio_dir, prefix="predicted_codes", extension=".pt") + def read_manifest(manifest_path): records = [] with open(manifest_path, 'r') as f: @@ -90,10 +98,14 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type): return embeddings.squeeze() -def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large"): - audio_file_lists = find_sample_audios(generated_audio_dir) +def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", + codecmodel_path=None): + audio_file_lists = find_generated_audio_files(generated_audio_dir) records = read_manifest(manifest_path) assert len(audio_file_lists) == len(records) + if codecmodel_path is not None: + codes_file_lists = find_generated_codec_files(generated_audio_dir) + assert len(codes_file_lists) == len(records) device = "cuda" @@ -125,7 +137,18 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo speaker_verification_model_alternate = speaker_verification_model_alternate.to(device) speaker_verification_model_alternate.eval() - + if codecmodel_path is not None: + codec = AudioCodecModel.restore_from(codecmodel_path, strict=False) + codec = codec.to(device) + codec.eval() + # The FCD metric measures a distance between generated and real codec frames. The distance + # is measured in the codec's embedding space. `codec_feature_dim` is the size of the codec's embedding vector. + # For example, for a group-FSQ codec with 8 codebooks with 4 values in each codebook, the embedding dimension is 8 x 4 = 32. 
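The FCD wired in here is the Fréchet distance between two Gaussians fitted to real and generated codec-frame embeddings, d^2 = |mu_r - mu_f|^2 + Tr(Sigma_r + Sigma_f - 2*(Sigma_r*Sigma_f)^(1/2)), accumulated over all valid frames of the test set. The sketch below shows how the new metric could be exercised on its own, outside evaluate_generated_audio.py; the codec path and audio/code file paths are placeholders, and the calls use only the FrechetCodecDistance API introduced in this patch (update_from_audio_file, update, compute, reset).

# Hypothetical standalone use of the new FCD metric; all paths are placeholders.
import torch

from nemo.collections.tts.models import AudioCodecModel
from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance

device = "cuda" if torch.cuda.is_available() else "cpu"
codec = AudioCodecModel.restore_from("/path/to/codec.nemo", strict=False).to(device).eval()

# Size of the codec's continuous embedding, e.g. 8 codebooks x 4 dims = 32 for a group-FSQ codec.
feature_dim = codec.vector_quantizer.codebook_dim
fcd_metric = FrechetCodecDistance(codec=codec, feature_dim=feature_dim).to(device)

# Real side: embed ground-truth audio files directly.
fcd_metric.update_from_audio_file("/path/to/ground_truth.wav", is_real=True)

# Fake side: reuse codec codes saved during inference; saved tensors are (C, T), so add a batch dim.
pred_codes = torch.load("/path/to/predicted_codes_0.pt").unsqueeze(0).to(device)
pred_codes_lens = torch.tensor([pred_codes.size(-1)], dtype=torch.int, device=device)
fcd_metric.update(pred_codes, pred_codes_lens, is_real=False)

print("FCD:", fcd_metric.compute().item())
fcd_metric.reset()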
+ codec_feature_dim = codec.vector_quantizer.codebook_dim + fcd_metric = FrechetCodecDistance(codec=codec, feature_dim=codec_feature_dim).to(device) + else: + print("No codec model provided, skipping FCD metric") + fcd_metric = None filewise_metrics = [] pred_texts = [] @@ -138,13 +161,17 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo gt_audio_filepath = os.path.join(audio_dir, gt_audio_filepath) if context_audio_filepath is not None: context_audio_filepath = os.path.join(audio_dir, context_audio_filepath) + # Update the FCD metric for *real* codes + if fcd_metric is not None: + fcd_metric.update_from_audio_file(gt_audio_filepath, True) pred_audio_filepath = audio_file_lists[ridx] + if fcd_metric is not None: + pred_codes_filepath = codes_file_lists[ridx] try: if language == "en": with torch.no_grad(): - # import ipdb; ipdb.set_trace() pred_text = asr_model.transcribe([pred_audio_filepath])[0].text pred_text = process_text(pred_text) gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0].text @@ -176,6 +203,12 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo gt_texts.append(gt_text) gt_audio_texts.append(gt_audio_text) + # update FCD metric + if fcd_metric is not None: + predicted_codes = torch.load(pred_codes_filepath).unsqueeze(0) # B, C, T + predicted_codes_lens = torch.tensor([predicted_codes.size(-1)], dtype=torch.int, device=device) + fcd_metric.update(predicted_codes, predicted_codes_lens, False) + pred_context_ssim = 0.0 gt_context_ssim = 0.0 with torch.no_grad(): @@ -197,8 +230,6 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo pred_context_ssim_alternate = torch.nn.functional.cosine_similarity(pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() gt_context_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() - - filewise_metrics.append({ 'gt_text': gt_text, 'pred_text': pred_text, @@ -226,6 +257,13 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo # Sort filewise metrics by cer in reverse filewise_metrics.sort(key=lambda x: x['cer'], reverse=True) + # compute frechet distance for the whole test set + if fcd_metric is not None: + fcd = fcd_metric.compute().cpu().item() + fcd_metric.reset() + else: + fcd = 0.0 + avg_metrics = {} avg_metrics['cer_filewise_avg'] = sum([m['detailed_cer'][0] for m in filewise_metrics]) / len(filewise_metrics) avg_metrics['wer_filewise_avg'] = sum([m['detailed_wer'][0] for m in filewise_metrics]) / len(filewise_metrics) @@ -239,6 +277,7 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo avg_metrics['ssim_gt_context_avg_alternate'] = sum([m['gt_context_ssim_alternate'] for m in filewise_metrics]) / len(filewise_metrics) avg_metrics["cer_gt_audio_cumulative"] = word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=True)[0] avg_metrics["wer_gt_audio_cumulative"] = word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=False)[0] + avg_metrics["frechet_codec_distance"] = fcd pprint.pprint(avg_metrics) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 86ac0f52525c..52b55afc8192 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -18,8 +18,8 @@ import os import shutil -import evalset_config -import 
evaluate_generated_audio +import scripts.magpietts.evalset_config as evalset_config +import scripts.magpietts.evaluate_generated_audio as evaluate_generated_audio import numpy as np import scipy.stats as stats import soundfile as sf @@ -31,7 +31,6 @@ from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset from nemo.collections.tts.models import MagpieTTSModel - def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): metrics = {} for key in metric_keys: @@ -228,6 +227,7 @@ def run_inference( item_idx = 0 all_rtf_metrics = [] + codec_file_paths = [] for bidx, batch in enumerate(test_data_loader): print("Processing batch {} out of {} of dataset {}".format(bidx, len(test_data_loader), dataset)) batch_cuda ={} @@ -239,7 +239,7 @@ def run_inference( import time st = time.time() - predicted_audio, predicted_audio_lens, _, _, rtf_metrics, cross_attention_maps, _ = model.infer_batch( + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics, cross_attention_maps, _ = model.infer_batch( batch_cuda, max_decoder_steps=440, temperature=temperature, @@ -256,6 +256,7 @@ def run_inference( use_local_transformer_for_inference=use_local_transformer, maskgit_n_steps=maskgit_n_steps ) + all_rtf_metrics.append(rtf_metrics) et = time.time() print(f"Time taken for inference: {et-st}", predicted_audio.size()) @@ -267,6 +268,9 @@ def run_inference( predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] audio_path = os.path.join(pred_audio_dir, f"predicted_audio_{item_idx}.wav") sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate) + codes_path = os.path.join(pred_audio_dir, f"predicted_codes_{item_idx}.pt") + torch.save(predicted_codes[idx][:predicted_codes_lens[idx]], codes_path) + codec_file_paths.append(codes_path) context_audio_path = manifest_records[item_idx].get('context_audio_filepath', None) target_audio_path = manifest_records[item_idx].get('audio_filepath', None) if context_audio_path is not None: @@ -290,6 +294,7 @@ def run_inference( language=language, sv_model_type=sv_model, asr_model_name=asr_model_name, + codecmodel_path=codecmodel_path ) metrics_n_repeated.append(metrics) with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: @@ -305,24 +310,28 @@ def run_inference( all_experiment_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") if not os.path.exists(all_experiment_csv): with open(all_experiment_csv, "w") as f: - f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative\n") + f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,frechet_codec_distance\n") with open(all_experiment_csv, "a") as f: - 
f.write(f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']}\n") + f.write(f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']},{metrics['frechet_codec_distance']}\n") print(f"Wrote metrics for {checkpoint_name} and {dataset} to {all_experiment_csv}") + # Clean up temporary codec files + for codes_file in codec_file_paths: + os.remove(codes_file) metric_keys = ['cer_filewise_avg', 'wer_filewise_avg', 'cer_cumulative', 'wer_cumulative', 'ssim_pred_gt_avg', 'ssim_pred_context_avg', 'ssim_gt_context_avg', 'ssim_pred_gt_avg_alternate', 'ssim_pred_context_avg_alternate', 'ssim_gt_context_avg_alternate', - 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative' + 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative', 'frechet_codec_distance' ] metrics_mean_ci = compute_mean_and_confidence_interval(metrics_n_repeated, metric_keys, confidence=confidence_level) all_experiment_csv_with_ci = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") if not os.path.exists(all_experiment_csv_with_ci): with open(all_experiment_csv_with_ci, "w") as f: - f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative\n") + f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,frechet_codec_distance\n") with open(all_experiment_csv_with_ci, "a") as f: - f.write(f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']}\n") + 
f.write(f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']},{metrics_mean_ci['frechet_codec_distance']}\n") print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") + def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') diff --git a/tests/collections/tts/modules/test_fcd_metric.py b/tests/collections/tts/modules/test_fcd_metric.py new file mode 100644 index 000000000000..6c2807716b05 --- /dev/null +++ b/tests/collections/tts/modules/test_fcd_metric.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance +from nemo.collections.tts.models import AudioCodecModel + + +class TestFrechetCodecDistance: + @pytest.fixture + def device(self): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @pytest.fixture + def codec(self, device, scope="session"): + return AudioCodecModel.from_pretrained("nvidia/low-frame-rate-speech-codec-22khz").to(device) + + @pytest.fixture + def metric(self, codec, device): + codec_feature_dim = codec.vector_quantizer.codebook_dim + return FrechetCodecDistance(codec=codec, feature_dim=codec_feature_dim).to(device) + + @pytest.mark.unit + def test_same_distribution(self, metric, device, codec): + """Test that FCD is close to zero when comparing identical distributions.""" + B, C, T = 3, codec.num_codebooks, 20 + codes = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + # Update with same codes for both real and fake + metric.update(codes, codes_len, is_real=True) + metric.update(codes, codes_len, is_real=False) + + eps = 0.01 + fcd = metric.compute() + assert fcd < eps and fcd >= 0, f"FCD value is {fcd} but should be close to 0" + metric.reset() + + @pytest.mark.unit + def test_different_distribution(self, metric, device, codec): + """Test that FCD is positive when comparing different distributions.""" + B, C, T = 3, codec.num_codebooks, 20 + + # Generate two different sets of codes + codes1 = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes2 = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + metric.update(codes1, codes_len, is_real=True) + metric.update(codes2, codes_len, 
is_real=False) + + fcd = metric.compute() + assert fcd > 0, f"FCD value is {fcd} but should be positive for different distributions" + metric.reset() + + def test_empty_distribution(self, metric): + """Test that computing the FCD on empty distributions returns 0.""" + fcd = metric.compute() + assert fcd == 0.0 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.unit + def test_gpu_compatibility(self, metric, device, codec): + """Test that the metric works correctly on GPU.""" + assert metric.device.type == "cuda" + B, C, T = 3, codec.num_codebooks, 20 + codes = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + metric.update(codes, codes_len, is_real=True) + metric.update(codes, codes_len, is_real=False) + + fcd = metric.compute() + + eps = 0.01 + assert isinstance(fcd, torch.Tensor) + assert fcd.device.type == "cuda" + assert fcd < eps and fcd >= 0, f"FCD value is {fcd} but should be close to 0" + + @pytest.mark.unit + def test_update_from_audio_file(self, metric): + """Test the update_from_audio_file method.""" + + # Test with both "real" and "fake" audio files (different files) + metric.update_from_audio_file("tests/.data/tts/mini_ljspeech/wavs/LJ019-0373.wav", is_real=True) + metric.update_from_audio_file("tests/.data/tts/mini_ljspeech/wavs/LJ050-0234.wav", is_real=False) + + fcd = metric.compute() + assert isinstance(fcd, torch.Tensor) + assert fcd > 0, f"FCD value is {fcd} but should be positive given that we tested different audio files" From 98d35f287ce14ec50248c45707b80307c67d96ee Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 3 Jun 2025 11:42:45 -0400 Subject: [PATCH 041/113] refactor 1: typos, yaml updates, code changes (#13677) Signed-off-by: Jason --- .../tts/conf/magpietts/magpietts_dc_en.yaml | 10 +- examples/tts/conf/magpietts/magpietts_en.yaml | 7 - .../magpietts/magpietts_inference_en.yaml | 8 - .../magpietts_inference_multilingual_v1.yaml | 8 - .../magpietts/magpietts_lhotse_dc_en.yaml | 8 +- .../magpietts/magpietts_multilingual_v1.yaml | 10 - nemo/collections/tts/models/magpietts.py | 207 ++++++++++-------- .../tts/modules/transformer_2501.py | 32 ++- 8 files changed, 127 insertions(+), 163 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index 075fe26f67f9..eef597a89f5e 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -12,13 +12,10 @@ weighted_sampling_steps_per_epoch: null train_ds_meta: ??? val_ds_meta: ??? -# Modify these values based on your sample rate -sample_rate: 22050 - model: model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. - context_duration_min: 3.0 + context_duration_min: 5.0 context_duration_max: 5.0 load_cached_codes_if_available: true prior_scaling_factor: 0.5 @@ -30,7 +27,6 @@ model: codecmodel_path: ??? 
max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} cfg_unconditional_prob: 0.1 # Alignment encoder parameters, to binarize the prior @@ -77,8 +73,6 @@ model: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${train_ds_meta} weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} min_duration: 0.2 max_duration: 20.0 @@ -92,8 +86,6 @@ model: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${val_ds_meta} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} min_duration: 0.2 max_duration: 20.0 diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index deeda0ed1e72..25dd489a2e2e 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -12,9 +12,6 @@ weighted_sampling_steps_per_epoch: null train_ds_meta: ??? val_ds_meta: ??? -# Modify these values based on your sample rate -sample_rate: 22050 - model: model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. @@ -79,8 +76,6 @@ model: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${train_ds_meta} weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} min_duration: 0.2 max_duration: 20.0 @@ -94,8 +89,6 @@ model: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${val_ds_meta} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} min_duration: 0.2 max_duration: 20.0 diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index 9b46b6cb42ee..3c252139a7e5 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -12,9 +12,6 @@ weighted_sampling_steps_per_epoch: null # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 test_ds_meta: ??? 
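The sample_rate entries being removed throughout these configs are no longer read from YAML at all: as the magpietts.py change later in this patch shows, the model now takes its sample rate directly from the loaded codec checkpoint. A minimal sketch of that lookup, with a placeholder codec path:

# Sketch only: the audio sample rate comes from the codec model, not from the config.
from nemo.collections.tts.models import AudioCodecModel

codec = AudioCodecModel.restore_from("/path/to/codec.nemo", strict=False)
sample_rate = codec.sample_rate  # e.g. 22050, the value these configs previously hard-coded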
-# Modify these values based on your sample rate -sample_rate: 22050 - phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" model: @@ -41,8 +38,6 @@ model: max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # Alignment encoder parameters, to binarize the prior # This is used for attention-constrained training and inference use_alignment_encoder: false @@ -86,10 +81,8 @@ model: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${test_ds_meta} - sample_rate: ${sample_rate} min_duration: 0.5 max_duration: 20.0 - # speaker_path: ${speaker_path} dataloader_params: batch_size: ${batch_size} @@ -161,7 +154,6 @@ trainer: logger: false # Provided by exp_manager log_every_n_steps: 100 val_check_interval: 500 - # check_val_every_n_epoch: 10 benchmark: false gradient_clip_val: 2.5 diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index 10af938fed6f..22d2b73fa3a0 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -12,9 +12,6 @@ weighted_sampling_steps_per_epoch: null # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 test_ds_meta: ??? -# Modify these values based on your sample rate -sample_rate: 22050 - phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" model: @@ -41,8 +38,6 @@ model: max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # Alignment encoder parameters, to binarize the prior # This is used for attention-constrained training and inference use_alignment_encoder: false @@ -134,10 +129,8 @@ model: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${test_ds_meta} - sample_rate: ${sample_rate} min_duration: 0.5 max_duration: 20.0 - # speaker_path: ${speaker_path} dataloader_params: batch_size: ${batch_size} @@ -209,7 +202,6 @@ trainer: logger: false # Provided by exp_manager log_every_n_steps: 100 val_check_interval: 500 - # check_val_every_n_epoch: 10 benchmark: false gradient_clip_val: 2.5 diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index d6b8c355b8e3..b8f19208149a 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -1,13 +1,12 @@ name: MagpieTTS-EN-Lhotse quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. -sample_rate: 22_050 model: use_lhotse: true model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. - context_duration_min: 3.0 + context_duration_min: 5.0 context_duration_max: 5.0 codec_model_name: "21fpsCausalDecoder" load_cached_codes_if_available: true @@ -18,7 +17,6 @@ model: alignment_loss_scale: 0.002 embedding_dim: 768 codecmodel_path: ??? 
- sample_rate: ${sample_rate} cfg_unconditional_prob: 0.1 # enable classifier-free guidance during traing by dropping out conditionals with this probability # Alignment encoder parameters, to binarize the prior @@ -66,7 +64,6 @@ model: dataset: min_duration: 0.2 - sample_rate: ${sample_rate} # need to override the default sample rate of 16000 batch_duration : ??? # in seconds. Adjust based on your GPU memory. quadratic_duration: ${quadratic_duration} use_bucketing: true @@ -93,7 +90,6 @@ model: dataset: min_duration: 0.2 - sample_rate: ${sample_rate} batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. quadratic_duration: ${quadratic_duration} use_bucketing: false @@ -177,7 +173,7 @@ exp_manager: entity: null project: null group: null - name: null + name: ${name} resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. create_checkpoint_callback: true checkpoint_callback_params: diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index 2c958b369886..bec174f81146 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -12,9 +12,6 @@ weighted_sampling_steps_per_epoch: null train_ds_meta: ??? val_ds_meta: ??? -# Modify these values based on your sample rate -sample_rate: 22050 - model: model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. @@ -34,8 +31,6 @@ model: max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # Alignment encoder parameters, to binarize the prior # This is used for attention-constrained training and inference use_alignment_encoder: false @@ -128,8 +123,6 @@ model: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${train_ds_meta} weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} min_duration: 0.5 max_duration: 20.0 @@ -142,8 +135,6 @@ model: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${val_ds_meta} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} min_duration: 0.5 max_duration: 20.0 @@ -217,7 +208,6 @@ trainer: logger: false # Provided by exp_manager log_every_n_steps: 100 val_check_interval: 500 - # check_val_every_n_epoch: 10 benchmark: false gradient_clip_val: 2.5 diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 18ba9d69223b..02ca887ed16e 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -53,7 +53,7 @@ def worker_init_fn(worker_id): ) dataset.text_tokenizer = tokenizer dataset.text_conditioning_tokenizer = text_conditioning_tokenizer - + class MagpieTTSModel(ModelPT): """ Magpie-TTS Model Base Class used for training a TTS model that can generate audio codes from transcript and a context @@ -86,6 +86,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # load codec codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) + self.sample_rate = codec_model.sample_rate # del codec discriminator to free memory del 
codec_model.discriminator @@ -125,7 +126,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.bos_id = num_tokens - 2 self.eos_id = num_tokens - 1 - self.model_type = cfg.get('model_type', 'single_encoder_sv_tts') + self.model_type = cfg.get('model_type', None) self.pad_context_text_to_max_duration = self.model_type == 'decoder_context_tts' self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) @@ -146,7 +147,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if self.model_type != 'decoder_pretrain_synthesizer': # Decoder pretrain synthesizer doesn't have transcript encoder/text embeddings - + if self.use_bpe_char_tokenizer: # BPE char tokenizer assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer" @@ -213,14 +214,14 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self._speaker_verification_model.freeze() #Lightning does requires_grad = False and self.eval() self.speaker_projection_layer = nn.Linear(cfg.speaker_emb_dim, cfg.embedding_dim) self.transcript_decoder_layers = [ - idx for idx in range(cfg.decoder.n_layers) + idx for idx in range(self.decoder.n_layers) ] # All layers are used for text elif self.model_type == 'multi_encoder_context_tts': self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8]) self.context_decoder_layers = cfg.get( 'context_decoder_layers', [0, 1, 2, 9, 10, 11] ) # For backward compatibility - multi_encoder_mapping = [None for _ in range(cfg.decoder.n_layers)] + multi_encoder_mapping = [None for _ in range(self.decoder.n_layers)] for layer in self.transcript_decoder_layers: multi_encoder_mapping[layer] = 0 # 0 means text goes to this layer, 1 means context goes to this layer for layer in self.context_decoder_layers: @@ -229,7 +230,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder)) elif self.model_type == 'decoder_context_tts': self.transcript_decoder_layers = [ - idx for idx in range(cfg.decoder.n_layers) + idx for idx in range(self.decoder.n_layers) ] # All layers are used for text elif self.model_type == 'decoder_pretrain_synthesizer': assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" @@ -237,13 +238,34 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): raise ValueError(f"Unsupported model type {self.model_type}") self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') - alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0) - alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0) - if alignment_loss_scale > 0.0: - self.alignment_loss = ForwardSumLoss(loss_scale=alignment_loss_scale) - if alignment_encoder_loss_scale > 0.0: - self.alignment_encoder_loss = ForwardSumLoss(loss_scale=alignment_encoder_loss_scale) - + self.alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0) + self.alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0) + if self.alignment_loss_scale > 0.0: + self.alignment_loss = ForwardSumLoss(loss_scale=self.alignment_loss_scale) + if self.alignment_encoder_loss_scale > 0.0: + self.alignment_encoder_loss = ForwardSumLoss(loss_scale=self.alignment_encoder_loss_scale) + + # Define cfg parameters into self parameters + self.prior_end_step = self.cfg.prior_end_step + self.prior_scaledown_start_step = self.cfg.prior_scaledown_start_step + 
self.indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0) + self.ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) + self.cfg_unconditional_prob = self.cfg.get('cfg_unconditional_prob', 0.0) + self.decoder_input_dropout_prob = self.cfg.get('decoder_input_dropout_prob', 0.0) + self.binarize_attn_method = self.cfg.get('binarize_attn_method', 'argmax') + self.binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2) + self.prior_future_decay = self.cfg.get('prior_future_decay', 1.0) + self.prior_past_decay = self.cfg.get('prior_past_decay', 1.0) + self.binarized_prior_epsilon = self.cfg.get('binarized_prior_epsilon', 0.0) + self.prior_future_context = self.cfg.get('prior_future_context', 1) + self.prior_past_context = self.cfg.get('prior_past_context', 1) + self.binarize_prior_after_step = self.cfg.get('binarize_prior_after_step', 0) + self.codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0) + self.local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0) + self.use_alignment_encoder = self.cfg.get('use_alignment_encoder', False) + self.use_prior_for_aligner = self.cfg.get('use_prior_for_aligner', False) + self.aligner_encoder_train_steps = self.cfg.get('aligner_encoder_train_steps', float('inf')) + self.dec_random_input_max = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook) def state_dict(self, destination=None, prefix='', keep_vars=False): @@ -260,7 +282,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): if any([substring in key for substring in keys_substrings_to_exclude]): del state_dict[key] return state_dict - + def load_state_dict(self, state_dict, strict=True): """ Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when @@ -370,7 +392,7 @@ def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_ +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | Seq. Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ - + dec_out: (B, T', E) audio_codes_target: (B, C, T') targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive) @@ -416,7 +438,7 @@ def maskgit_create_random_mask(self, codes): B,C,T = codes.shape # get a uniform random vector uniformly sampled from [0,1) ## Todo does it need to be inclusive on the right? 
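# cosine_schedule is not shown in this diff; it is assumed here to follow the standard
# MaskGit masking schedule, i.e. roughly:
#   def cosine_schedule(r):  # r uniform in [0, 1)
#       return torch.cos(r * math.pi / 2)
# so each randomly drawn r maps to the fraction of codebooks to mask at that timestep,
# ranging from almost all masked (r near 0) to almost none masked (r near 1).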
rand_values = torch.rand(B,T, device=codes.device) - # apply the cosine schedule + # apply the cosine schedule frac_masked = cosine_schedule(rand_values) # how many positions to mask n_masked = torch.ceil(frac_masked * C).long() # B,T @@ -431,19 +453,19 @@ def maskgit_create_random_mask(self, codes): # # mask the top n_masked positions # mask[b, perm[:n_masked[b,t]], t] = True # - # Create random permutations - random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T) - # Create a mask tensor where each position indicates if it should be masked + # Create random permutations + random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T) + # Create a mask tensor where each position indicates if it should be masked mask_indices = torch.arange(C, device=codes.device).view(1, C, 1) mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T) # Apply the random permutations to the mask mask = torch.gather(mask, 1, random_permutations) - + return mask # (B, C, T) - + def maskgit_apply_random_mask(self, codes): # Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule. - # Codes: (B, C, T) + # Codes: (B, C, T) mask = self.maskgit_create_random_mask(codes) ## replace some tokens with MASK_TOKEN codes_with_mask = torch.where(mask, self.mask_token_id, codes) @@ -454,7 +476,7 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=N Computes the audio codebook loss. Used by (1) The main Magpie-TTS transformer (2) The local transformer, for both autoregressive and MaskGit methods - + logits: (B, T', num_codebooks * num_tokens_per_codebook) audio_codes: (B, C, T') audio_codes_lens: (B,) @@ -471,8 +493,8 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=N if not loss_mask.any(): # Without this we were very rarely getting NaNs in the loss logging.warning("No tokens valid were found in compute_loss()!") - return torch.tensor(0.0, device=loss_mask.device), loss_mask - else: + return torch.tensor(0.0, device=loss_mask.device), loss_mask + else: # repeat loss mask for each codebook to simplify code below loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1) total_codebook_loss = None @@ -557,8 +579,8 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, # replace masks of the top-k confident codebooks with the the codes that were sampled for them unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices) codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) - - # build transformer input + + # build transformer input local_transformer_input = local_transformer_input_init for codebook_num in range(C): next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze(1) # (B, 1, 768) @@ -568,7 +590,7 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, # run transformer _mask = torch.ones(B, C+1, device=device) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, C+1, d_local) - + # get logits logits = [] for codebook_num in range(C): @@ -591,7 +613,7 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, for item_idx in finished_items: logits[item_idx, :, :] = float('-inf') logits[item_idx, :, self.audio_eos_id] = 0.0 - + # sample with top-k logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk) indices_to_remove = logits < 
logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook) @@ -611,7 +633,7 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, # replace entries in sampled_codes with previously unmasked codebooks sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) # optionally: add noise to confidences here (as in token-critic paper) (not implemented) - + codes = sampled_codes assert not (codes == self.mask_token_id).any(), f"Codes contain mask tokens after completion of MaskGit sampling" if use_cfg: @@ -759,9 +781,9 @@ def log_val_audio_example( if is_wandb: wandb_audio_log[f"Audio/Example_{idx}"] = list() if context_audio_np is not None: - wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(context_audio_np, sample_rate=self.cfg.sample_rate, caption="context")) - wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(pred_audio_np, sample_rate=self.cfg.sample_rate, caption="prediction")) - wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(target_audio_np, sample_rate=self.cfg.sample_rate, caption="target")) + wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context")) + wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction")) + wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target")) if is_tb: if context_audio_np is not None: @@ -769,19 +791,19 @@ def log_val_audio_example( f'Example_{idx}/context', context_audio_np, global_step=self.global_step, - sample_rate=self.cfg.sample_rate, + sample_rate=self.sample_rate, ) logger.experiment.add_audio( f'Example_{idx}/prediction', pred_audio_np, global_step=self.global_step, - sample_rate=self.cfg.sample_rate, + sample_rate=self.sample_rate, ) logger.experiment.add_audio( f'Example_{idx}/target', target_audio_np, global_step=self.global_step, - sample_rate=self.cfg.sample_rate, + sample_rate=self.sample_rate, ) return wandb_audio_log @@ -789,13 +811,10 @@ def log_val_audio_example( def scale_prior(self, prior, global_step): if prior is None: return None - prior_end_step = self.cfg.prior_end_step - prior_scaledown_start_step = self.cfg.prior_scaledown_start_step - if global_step < prior_scaledown_start_step: + if global_step < self.prior_scaledown_start_step: return prior - elif global_step >= prior_end_step: - indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0) - if random.random() < indefinite_prior_prob: + elif global_step >= self.prior_end_step: + if random.random() < self.indefinite_prior_prob: print("Using Prior") return prior else: @@ -807,17 +826,17 @@ def scale_prior(self, prior, global_step): residual = 1.0 - prior new_prior = prior + ( residual - * (global_step - prior_scaledown_start_step) - / (prior_end_step - prior_scaledown_start_step) + * (global_step - self.prior_scaledown_start_step) + / (self.prior_end_step - self.prior_scaledown_start_step) ) return new_prior - + def embed_text(self, text, text_mask): if self.use_bpe_char_tokenizer: text_embedded = self.cas_encoder(text, subword_mask=text_mask) else: text_embedded = self.text_embedding(text) - + return text_embedded def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_context_size=0): @@ -837,7 +856,7 @@ def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_co def prepare_context_tensors(self, batch): dec_context_size = 0 
additional_decoder_input = None - addtional_decoder_mask = None + additional_decoder_mask = None context_audio_codes = None context_audio_codes_lens = None _attn_prior = None @@ -936,21 +955,20 @@ def prepare_context_tensors(self, batch): cond_mask = text_mask multi_encoder_mapping = None additional_decoder_input = context_embeddings - addtional_decoder_mask = context_mask + additional_decoder_mask = context_mask elif self.model_type == 'decoder_pretrain_synthesizer': pass else: raise ValueError(f"Unsupported model type {self.model_type}") - if attn_prior is not None and self.cfg.get('ctc_prior_layer_ids', None) is not None: - ctc_prior_layer_ids = self.cfg.ctc_prior_layer_ids + if attn_prior is not None and self.ctc_prior_layer_ids is not None: # Convert prior to a list of tensors, one for each layer # Set None for layers not in ctc_prior_layer_ids if self.model_type == 'multi_encoder_context_tts': - text_attn_prior = [attn_prior[0] if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.decoder.n_layers) ] + text_attn_prior = [attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None for layer_idx in range(self.decoder.n_layers) ] attn_prior = [text_attn_prior, attn_prior[1]] else: - attn_prior = [attn_prior if layer_idx in ctc_prior_layer_ids else None for layer_idx in range(self.cfg.decoder.n_layers) ] + attn_prior = [attn_prior if layer_idx in self.ctc_prior_layer_ids else None for layer_idx in range(self.decoder.n_layers) ] return { 'beta_binomial_attn_prior': batch.get('align_prior_matrix', None), @@ -961,7 +979,7 @@ def prepare_context_tensors(self, batch): 'prior_used': _attn_prior is not None, 'multi_encoder_mapping': multi_encoder_mapping, 'additional_decoder_input': additional_decoder_input, - 'addtional_decoder_mask': addtional_decoder_mask, + 'additional_decoder_mask': additional_decoder_mask, 'dec_context_size': dec_context_size, 'text': text, 'text_embedded': text_embedded, @@ -1001,31 +1019,30 @@ def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_ha def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens): # aligner_attn_soft B, 1, audio_timesteps, text_timesteps - if self.cfg.get('binarize_attn_method', 'argmax') == 'nemo_binarize': - binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2) + if self.binarize_attn_method == 'nemo_binarize': + logging.info("Binarizing attention using nemo_binarize") + binarize_repeat_audio_factor = self.binarize_repeat_audio_factor aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave(binarize_repeat_audio_factor, dim=2) # B, 1, 2*audio_timesteps, text_timesteps aligner_attn_hard = binarize_attention_parallel(aligner_attn_soft_repeated, text_lens, audio_lens*binarize_repeat_audio_factor).squeeze(1) # B, 2*audio_timesteps, text_timesteps aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps - else: - print("Binaraizing attention using argmax") + elif self.binarize_attn_method == 'argmax': + logging.info("Binarizing attention using argmax") aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1) aligner_attn_hard = torch.nn.functional.one_hot(aligner_attn_hard, num_classes=aligner_attn_soft.size(-1)).float() + else: + raise ValueError(f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'.") - prior_future_decay = self.cfg.get('prior_future_decay', 1.0) - prior_past_decay = self.cfg.get('prior_past_decay', 1.0) - binarized_prior_epsilon = 
self.cfg.get('binarized_prior_epsilon', 0.0) - aligner_attn_hard_wider = aligner_attn_hard + binarized_prior_epsilon + aligner_attn_hard_wider = aligner_attn_hard + self.binarized_prior_epsilon - for future_timestep in range(self.cfg.get('prior_future_context', 1)): - decay_factor = prior_future_decay ** (future_timestep + 1) + for future_timestep in range(self.prior_future_context): + decay_factor = self.prior_future_decay ** (future_timestep + 1) aligner_attn_hard_wider[:,:,future_timestep+1:] += decay_factor * aligner_attn_hard[:,:,:-(future_timestep+1)] - for past_timestep in range(self.cfg.get('prior_past_context', 1)): - decay_factor = prior_past_decay ** (past_timestep + 1) + for past_timestep in range(self.prior_past_context): + decay_factor = self.prior_past_decay ** (past_timestep + 1) aligner_attn_hard_wider[:,:,:-past_timestep-1] += decay_factor * aligner_attn_hard[:,:,past_timestep+1:] aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0) - return aligner_attn_hard_wider def prepare_dummy_cond_for_cfg(self, cond, cond_mask, additional_decoder_input, additional_dec_mask): @@ -1075,17 +1092,17 @@ def process_batch(self, batch, mode="train"): audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input) use_cfg = ( - (self.cfg.get('cfg_unconditional_prob', 0.0) > 0.0) + (self.cfg_unconditional_prob > 0.0) and (mode == "train") and (context_tensors['cond'] is not None) ) - if use_cfg and torch.rand(1).item() < self.cfg.cfg_unconditional_prob: + if use_cfg and torch.rand(1).item() < self.cfg_unconditional_prob: cond, cond_mask, additional_decoder_input, additional_decoder_mask, attn_prior = ( self.prepare_dummy_cond_for_cfg( context_tensors['cond'], context_tensors['cond_mask'], context_tensors['additional_decoder_input'], - context_tensors['addtional_decoder_mask'], + context_tensors['additional_decoder_mask'], ) ) disable_alignment_loss = True @@ -1093,16 +1110,16 @@ def process_batch(self, batch, mode="train"): cond = context_tensors['cond'] cond_mask = context_tensors['cond_mask'] additional_decoder_input = context_tensors['additional_decoder_input'] - additional_decoder_mask = context_tensors['addtional_decoder_mask'] + additional_decoder_mask = context_tensors['additional_decoder_mask'] attn_prior = context_tensors['attn_prior'] if ( mode == "train" - and self.cfg.get('decoder_input_dropout_prob', 0.0) > 0.0 + and self.decoder_input_dropout_prob > 0.0 and torch.rand(1).item() < 0.5 ): # For some batches (half of them), replace decoder_input_dropout_prob of the timesteps with random tokens - max_codebook_val = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook) + max_codebook_val = self.dec_random_input_max # @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens # which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on # audio_codes_input so should not matter if we don't supply dec_random_input_max. 
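To make the decoder-input dropout just above concrete, here is a small self-contained sketch of the same trick; the shapes and the dropout probability are illustrative values, not taken from any config in this patch. A per-timestep boolean mask keeps most positions of the input codes and swaps the remainder for random codebook entries.

# Illustrative shapes: 2 items, 8 codebooks, 100 codec frames, 1000 tokens per codebook.
import torch

B, C, T, num_tokens = 2, 8, 100, 1000
dropout_prob = 0.1  # assumed value of decoder_input_dropout_prob for this sketch

audio_codes_input = torch.randint(0, num_tokens, (B, C, T))
random_audio_tokens = torch.randint(0, num_tokens, (B, C, T))

# True where the original timestep is kept; broadcasts over batch and codebook dimensions.
keep_mask = torch.rand(1, 1, T) > dropout_prob
audio_codes_input = audio_codes_input * keep_mask + random_audio_tokens * (~keep_mask)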
@@ -1112,7 +1129,7 @@ def process_batch(self, batch, mode="train"): random_audio_tokens = random_audio_tokens * audio_codes_mask.unsqueeze(1) dec_dropout_mask = ( torch.rand((1, 1, audio_codes_input.size(2)), device=audio_codes_input.device) - > self.cfg.decoder_input_dropout_prob + > self.decoder_input_dropout_prob ) # timestep_mask is True for timesteps to be kept audio_codes_input = audio_codes_input * dec_dropout_mask + random_audio_tokens * (~dec_dropout_mask) @@ -1128,12 +1145,12 @@ def process_batch(self, batch, mode="train"): aligner_encoder_loss = None aligner_attn_soft = None aligner_attn_hard = None - if self.cfg.get('use_alignment_encoder', False) and not disable_alignment_loss: + if self.use_alignment_encoder and not disable_alignment_loss: aligner_prior = None - if self.cfg.get('use_prior_for_aligner', False): + if self.use_prior_for_aligner: aligner_prior = context_tensors['beta_binomial_attn_prior'] # Passing target audio embeddings to the alignment encoder - if self.global_step < self.cfg.get('aligner_encoder_train_steps', float('inf')): + if self.global_step < self.aligner_encoder_train_steps: aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T @@ -1158,8 +1175,7 @@ def process_batch(self, batch, mode="train"): aligner_attn_hard = self.get_binarized_prior_matrix( aligner_attn_soft, audio_codes_lens_input, context_tensors['text_lens'] ) - if (self.global_step > self.cfg.get('binarize_prior_after_step', 0)) and context_tensors['prior_used']: - print("Updating Prior") + if (self.global_step > self.binarize_prior_after_step) and context_tensors['prior_used']: attn_prior = self.replace_beta_binomial_prior_with_binarized(attn_prior, aligner_attn_hard) logits, attn_info, dec_out = self.forward( @@ -1176,18 +1192,16 @@ def process_batch(self, batch, mode="train"): logits = logits[:, dec_context_size:, :] # Remove the context audio embeddings from the logits codebook_loss, loss_mask = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target) - codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0) alignment_loss = None - if self.cfg.alignment_loss_scale > 0.0 and not disable_alignment_loss: + if self.alignment_loss_scale > 0.0 and not disable_alignment_loss: text_lens = context_tensors['text_lens'] - ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) - cross_attention_scores = [attn['cross_attn_probabilities'][1] for layer_idx, attn in enumerate(attn_info) if layer_idx in ctc_prior_layer_ids] + cross_attention_scores = [attn['cross_attn_probabilities'][1] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids] alignment_loss = self.compute_alignment_loss( cross_attention_scores, text_lens, audio_codes_lens_target, dec_context_size ) - loss = codebook_loss_scale * codebook_loss + alignment_loss + loss = self.codebook_loss_scale * codebook_loss + alignment_loss else: - loss = codebook_loss_scale * codebook_loss + loss = self.codebook_loss_scale * codebook_loss local_transformer_loss = None local_transformer_logits = None @@ -1203,8 +1217,7 @@ def process_batch(self, batch, mode="train"): assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type" local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_target, targets_offset_by_one=False) 
local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target, audio_codes_lens_target, None) - local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0) - loss = loss + local_transformer_loss_scale * local_transformer_loss + loss = loss + self.local_transformer_loss_scale * local_transformer_loss if aligner_encoder_loss is not None: loss = loss + aligner_encoder_loss @@ -1235,7 +1248,7 @@ def training_step(self, batch, batch_idx): loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) - if self.cfg.get('cfg_unconditional_prob', 0.0) == 0.0: + if self.cfg_unconditional_prob == 0.0: # Only log alignment loss when not using cfg to avoid sync issues when # alignment loss is None on some ranks alignment_loss = batch_output['alignment_loss'] @@ -1315,8 +1328,7 @@ def validation_step(self, batch, batch_idx): and len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1 ): # cross_attn_probabilities only returned when not using flash attention - ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) - cross_attention_probs = [attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) if layer_idx in ctc_prior_layer_ids] + cross_attention_probs = [attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids] wandb_log_dict.update( self.log_attention_probs( cross_attention_probs, @@ -1526,7 +1538,7 @@ def infer_batch( context_tensors['cond'], context_tensors['cond_mask'], context_tensors['additional_decoder_input'], - context_tensors['addtional_decoder_mask'], + context_tensors['additional_decoder_mask'], ) ) @@ -1548,13 +1560,13 @@ def infer_batch( _audio_codes_embedded = torch.cat( [context_tensors['additional_decoder_input'], audio_codes_embedded], dim=1 ) - _audio_codes_mask = torch.cat([context_tensors['addtional_decoder_mask'], audio_codes_mask], dim=1) + _audio_codes_mask = torch.cat([context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1) else: _audio_codes_embedded = audio_codes_embedded _audio_codes_mask = audio_codes_mask if apply_prior_to_layers is not None: - attn_prior = [None for _ in range(self.cfg.decoder.n_layers)] + attn_prior = [None for _ in range(self.decoder.n_layers)] for layer_idx in apply_prior_to_layers: attn_prior[layer_idx] = _attn_prior else: @@ -1711,7 +1723,7 @@ def infer_batch( predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens) end_time = time.time() - total_audio_duration_generated = (predicted_audio_lens.max().item() * predicted_audio_lens.shape[0])/self._codec_model.sample_rate + total_audio_duration_generated = (predicted_audio_lens.max().item() * predicted_audio_lens.shape[0])/self.sample_rate rtf = total_audio_duration_generated / (end_time - start_time) rtf_metrics = { 'rtf': rtf, @@ -1762,7 +1774,7 @@ def test_step(self, batch, batch_idx): if is_wandb: log_dict = { f"test/predicted_audio": wandb.Audio( - predicted_audio_np, sample_rate=self.cfg.sample_rate, caption=f"Predicted Audio" + predicted_audio_np, sample_rate=self.sample_rate, caption=f"Predicted Audio" ), } logger.experiment.log(log_dict, step=item_idx) @@ -1772,7 +1784,7 @@ def test_step(self, batch, batch_idx): 'test/predicted_audio', predicted_audio_np, global_step=item_idx, - sample_rate=self.cfg.sample_rate, + sample_rate=self.sample_rate, ) # Save 
the predicted audio @@ -1781,7 +1793,7 @@ def test_step(self, batch, batch_idx): if not os.path.exists(audio_dir): os.makedirs(audio_dir) audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') - sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) + sf.write(audio_path, predicted_audio_np, self.sample_rate) def on_validation_epoch_end(self): collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() @@ -1801,6 +1813,7 @@ def on_validation_epoch_end(self): def get_dataset(self, dataset_cfg, dataset_type): dataset = instantiate( dataset_cfg.dataset, + sample_rate=self.sample_rate, bos_id=self.bos_id, eos_id=self.eos_id, audio_bos_id=self.audio_bos_id, @@ -1827,7 +1840,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also # cfg is a classifier-free guidance. dataset = MagpieTTSLhotseDataset( - sample_rate=self.cfg.sample_rate, + sample_rate=self.sample_rate, volume_norm=dataset_cfg.volume_norm, codec_model_samples_per_frame=self.codec_model_samples_per_frame, codec_model_name=self.cfg.codec_model_name, diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index 2fe354df7765..a69846e97125 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -142,6 +142,7 @@ def __init__( d_model: int, p_dropout: float, is_causal: bool = True, + d_head: Optional[int] = None, ): """ Base Attention parent class. Users should not be instantiating this class, but rather use SelfAttention or @@ -154,10 +155,11 @@ def __init__( d_model (int): Dimension of the model. p_dropout (float): Dropout probability. is_causal (bool): Whether to use causal attention. Only supported when used in SelfAttention. + d_head (int): Head dimension. Defaults to d_model // n_heads. 
""" super().__init__() assert d_model % n_heads == 0, "d_model % n_head != 0" - self.d_head = d_model // n_heads + self.d_head = d_head if d_head is not None else d_model // n_heads self.n_heads = n_heads self.d_model = d_model self.scale = self.d_head**-0.5 @@ -450,7 +452,6 @@ def __init__( ) if self.has_xattn: - self.apply_norm_to_cond = apply_norm_to_cond self.norm_xattn_query = torch.nn.LayerNorm(d_model, bias=False) self.cross_attention = CrossAttention( n_heads=xa_n_heads, @@ -460,7 +461,8 @@ def __init__( make_prior_window_strict=make_prior_window_strict, ) - if self.apply_norm_to_cond: + self.norm_xattn_memory = torch.nn.Identity() + if apply_norm_to_cond: self.norm_xattn_memory = torch.nn.LayerNorm(xa_d_memory, bias=False) self.norm_pos_ff = torch.nn.LayerNorm(d_model, bias=False) @@ -521,7 +523,7 @@ def forward( if self.use_cache and self.cache['memory'] is not None: memory = self.cache['memory'] else: - memory = self.norm_xattn_memory(cond) if self.apply_norm_to_cond else cond + memory = self.norm_xattn_memory(cond) if self.use_cache: self.cache['memory'] = memory @@ -592,22 +594,20 @@ def __init__( raise ValueError("It requires that `xa_d_memory` and `xa_n_heads` are specified when `has_xattn` is True!") super().__init__() + self.n_layers = n_layers self.dropout = torch.nn.Dropout(p_dropout) self.p_dropout_out = p_dropout_out + self.dropout_out = torch.nn.Identity() if self.p_dropout_out > 0.0: self.dropout_out = torch.nn.Dropout(self.p_dropout_out) - else: - self.dropout_out = None - self.apply_norm_out = apply_norm_out - if self.apply_norm_out: + self.norm_out = torch.nn.Identity() + if apply_norm_out: self.norm_out = torch.nn.LayerNorm(d_model, bias=False) - else: - self.norm_out = None self.layers = torch.nn.ModuleList() - for _ in range(n_layers): + for _ in range(self.n_layers): self.layers.append( TransformerLayer( d_model=d_model, @@ -636,7 +636,7 @@ def __init__( self.apply(self._init_weights_gpt2) for name, param in self.named_parameters(): if 'o_net' in name and name.endswith('weight'): - torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * n_layers)) + torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers)) def reset_cache(self, use_cache=False): for layer in self.layers: @@ -728,10 +728,6 @@ def forward( if max_layer_idx is not None and idx == max_layer_idx: break - if self.norm_out is not None: - x = self.norm_out(x) - - if self.dropout_out is not None: - x = self.dropout_out(x) - + x = self.norm_out(x) + x = self.dropout_out(x) return {'output': x, 'attn_probabilities': attn_probabilities} From a445ab3b30d23791e3ca17889c093c5dd8c696a2 Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 9 Jun 2025 11:39:06 -0400 Subject: [PATCH 042/113] Re-enable CI (#13857) Signed-off-by: Jason --- .github/workflows/_build_container.yml | 89 ++++++++++++++++++++++++++ .github/workflows/cicd-main.yml | 9 +-- .github/workflows/copyright-check.yml | 4 +- 3 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/_build_container.yml diff --git a/.github/workflows/_build_container.yml b/.github/workflows/_build_container.yml new file mode 100644 index 000000000000..97d4a2a960c2 --- /dev/null +++ b/.github/workflows/_build_container.yml @@ -0,0 +1,89 @@ +name: ~Build container template +on: + workflow_call: + inputs: + image-name: + required: true + type: string + description: "The name of the image to build" + dockerfile: + required: true + type: string + runner: + required: false + default: self-hosted-azure-builder + type: 
string + description: "The runner to use for the build" + +jobs: + pre-flight: + runs-on: ubuntu-latest + outputs: + build_args: ${{ steps.manifest.outputs.BUILD_ARGS }} + cache-from: ${{ steps.cache_from.outputs.LAST_PRS }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Parse manifest.json + id: manifest + run: | + BUILD_ARGS=$(cat << EOF + BASE_IMAGE=$(cat requirements/manifest.json | jq -r '."ngc-pytorch"') + TRTLLM_REPO=$(cat requirements/manifest.json | jq -r '."vcs-dependencies"."trt-llm".repo') + TRTLLM_TAG=$(cat requirements/manifest.json | jq -r '."vcs-dependencies"."trt-llm".ref') + MLM_REPO=$(cat requirements/manifest.json | jq -r '."vcs-dependencies"."megatron-lm".repo') + MLM_TAG=$(cat requirements/manifest.json | jq -r '."vcs-dependencies"."megatron-lm".ref') + TE_REPO=$(cat requirements/manifest.json | jq -r '."vcs-dependencies".transformer_engine.repo') + TE_TAG=$(cat requirements/manifest.json | jq -r '."vcs-dependencies".transformer_engine.ref') + APEX_REPO=$(cat requirements/manifest.json | jq -r '."vcs-dependencies".apex.repo') + APEX_TAG=$(cat requirements/manifest.json | jq -r '."vcs-dependencies".apex.ref') + EOF + ) + + echo "BUILD_ARGS<> $GITHUB_OUTPUT + echo "$BUILD_ARGS" >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + - name: Get last merged PR + id: cache_from + env: + GH_TOKEN: ${{ github.token }} + run: | + LAST_PRS=$(gh api graphql -f query=' + query { + repository(owner: "NVIDIA", name: "NeMo") { + pullRequests(states: MERGED, first: 100, orderBy: {field: UPDATED_AT, direction: DESC}) { + nodes { + number + } + } + } + }' | jq -r '.data.repository.pullRequests.nodes[].number' | while read -r number; do + echo "nemoci.azurecr.io/${{ inputs.image-name }}-buildcache:$number" + done) + + echo "LAST_PRS<> $GITHUB_OUTPUT + echo "$LAST_PRS" >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + build: + uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_build_container.yml@v0.27.0 + needs: [pre-flight] + with: + image-name: ${{ inputs.image-name }} + dockerfile: ${{ inputs.dockerfile }} + image-label: nemo-core + build-args: | + IMAGE_LABEL=nemo-core + NEMO_TAG=${{ github.sha }} + NEMO_REPO=https://github.com/NVIDIA/NeMo + PR_NUMBER=${{ github.event.pull_request.number || 0 }} + ${{ needs.pre-flight.outputs.build_args }} + prune-filter-timerange: 24h + use-inline-cache: false + cache-from: | + nemoci.azurecr.io/${{ inputs.image-name }}-buildcache:main + nemoci.azurecr.io/${{ inputs.image-name }}-buildcache:${{ github.event.pull_request.number || 0 }} + ${{ needs.pre-flight.outputs.cache-from }} + runner: ${{ inputs.runner }} diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 86e7f16ea195..fdf0b0c94876 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -106,18 +106,11 @@ jobs: cicd-test-container-build: if: ${{ needs.pre-flight.outputs.test_to_run != '[]' }} - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_build_container.yml@v0.14.0 + uses: ./.github/workflows/_build_container.yml needs: [pre-flight, code-linting] with: image-name: nemo_container dockerfile: Dockerfile.ci - image-label: nemo-core - build-args: | - IMAGE_LABEL=nemo-core - NEMO_TAG=${{ github.sha }} - NEMO_REPO=https://github.com/NVIDIA/NeMo - ${{ needs.pre-flight.outputs.BUILD_ARGS }} - prune-filter-timerange: 24h cicd-import-tests: if: ${{ needs.pre-flight.outputs.test_to_run != '[]' }} diff --git a/.github/workflows/copyright-check.yml b/.github/workflows/copyright-check.yml index 
ebd35c51dc44..803889bb62c4 100644 --- a/.github/workflows/copyright-check.yml +++ b/.github/workflows/copyright-check.yml @@ -14,9 +14,9 @@ name: Copyright check -on: +on: pull_request: jobs: copyright-check: - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_copyright_check.yml@v0.2.0 \ No newline at end of file + uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_copyright_check.yml@v0.2.0 \ No newline at end of file From f116e9ea8823f95b820456ca25c1a2de4dbfb2a8 Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Mon, 9 Jun 2025 09:25:48 -0700 Subject: [PATCH 043/113] Fix num_codebooks in Group RVQ (#13841) Signed-off-by: Ryan Co-authored-by: Jason --- nemo/collections/tts/modules/encodec_modules.py | 3 ++- tests/collections/tts/modules/test_audio_codec_modules.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index a1c552e21074..377f5671161e 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -847,6 +847,7 @@ class GroupResidualVectorQuantizer(VectorQuantizerBase): def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwargs): super().__init__() + self._num_codebooks = num_codebooks self.num_groups = num_groups self.codebook_dim = codebook_dim @@ -870,7 +871,7 @@ def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwa @property def num_codebooks(self): """Returns the number of codebooks.""" - return self.num_groups + return self._num_codebooks @property def codebook_size(self): diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py index 325fde5b3bf7..e1429df4fb70 100644 --- a/tests/collections/tts/modules/test_audio_codec_modules.py +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -240,7 +240,6 @@ def test_rvq_eval(self, num_codebooks: int): torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch') torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch') - @pytest.mark.pleasefixme @pytest.mark.unit @pytest.mark.parametrize('num_groups', [1, 2, 4]) @pytest.mark.parametrize('num_codebooks', [1, 4]) From 6f87a6946bc03c7a5bfc27e641bb58f645ddf418 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Mon, 9 Jun 2025 16:26:03 -0700 Subject: [PATCH 044/113] [tranformer_core][magpietts_config] added support to override xattn head dim by xa_d_head (#13856) * [tranformer_core][magpietts_config] added support to override cross attention head dim by xa_d_head Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [tranformer_core][magpietts_config] removed changes in yaml configs. 
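As a rough sketch of the head-dimension resolution this change builds on (the `ExampleAttention` class below is illustrative, not the actual class in transformer_2501.py; only the `d_head` fallback logic mirrors the patch):

    from typing import Optional
    import torch

    class ExampleAttention(torch.nn.Module):
        # Mirrors the patched Attention base class: d_head may now be set
        # independently of d_model // n_heads, so cross-attention (via xa_d_head)
        # can use a wider or narrower head dimension than self-attention.
        def __init__(self, d_model: int, n_heads: int, d_head: Optional[int] = None):
            super().__init__()
            assert d_model % n_heads == 0, "d_model % n_heads != 0"
            self.d_head = d_head if d_head is not None else d_model // n_heads
            self.n_heads = n_heads
            self.scale = self.d_head**-0.5
            # Projections are sized from n_heads * d_head rather than d_model,
            # so overriding d_head changes only the attention width.
            self.q_net = torch.nn.Linear(d_model, n_heads * self.d_head, bias=False)

    # Default: d_head falls back to d_model // n_heads.
    assert ExampleAttention(d_model=512, n_heads=8).d_head == 64
    # With an override such as xa_d_head=128, heads widen without touching d_model.
    assert ExampleAttention(d_model=512, n_heads=8, d_head=128).d_head == 128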
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Jason --- nemo/collections/tts/modules/transformer_2501.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index a69846e97125..59b1986ac348 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -352,6 +352,7 @@ def __init__( d_memory: int, p_dropout: float, make_prior_window_strict: bool = False, + d_head: Optional[int] = None, ): """ Implements CrossAttention. See parent class for forward implementation. Must be non-causal. @@ -362,15 +363,15 @@ def __init__( d_memory (int): Dimension of the conditioning / cross-attention input. p_dropout (float): Dropout probability. make_prior_window_strict (bool): Make attention scores lowest where prior is zero. + d_head (int): Head dimension. if None, defaults to d_model // n_heads in parent class. """ super().__init__( n_heads=n_heads, d_model=d_model, p_dropout=p_dropout, is_causal=False, + d_head=d_head, ) - if d_memory is None: - raise ValueError("d_memory must be provided for cross-attention") self.q_net = torch.nn.Linear(d_model, n_heads * self.d_head, bias=False) self.kv_net = torch.nn.Linear(d_memory, 2 * n_heads * self.d_head, bias=False) self.make_prior_window_strict = make_prior_window_strict @@ -416,6 +417,7 @@ def __init__( has_xattn: bool, xa_d_memory: Optional[int] = None, xa_n_heads: Optional[int] = None, + xa_d_head: Optional[int] = None, is_causal: bool = True, apply_norm_to_cond: bool = True, max_length_causal_mask: int = 4096, @@ -433,6 +435,7 @@ def __init__( has_xattn : Whether to use cross attention xa_d_memory : Hidden dimension for cross attention xa_n_heads : Number of attention heads used in cross attention + xa_d_head : Head dimension for cross attention. if None, defaults to d_model // xa_n_heads in Attention class. is_causal : Whether to use causal attention apply_norm_to_cond : Whether to apply normalization to conditioning tensor max_length_causal_mask : Maximum length of causal mask @@ -459,6 +462,7 @@ def __init__( d_memory=xa_d_memory, p_dropout=p_dropout, make_prior_window_strict=make_prior_window_strict, + d_head=xa_d_head, ) self.norm_xattn_memory = torch.nn.Identity() @@ -559,6 +563,7 @@ def __init__( has_xattn: bool = False, xa_d_memory: Optional[int] = None, xa_n_heads: Optional[int] = None, + xa_d_head: Optional[int] = None, is_causal: bool = True, apply_norm_to_cond: bool = True, apply_norm_out: bool = False, @@ -581,6 +586,7 @@ def __init__( has_xattn : Whether to use cross attention xa_d_memory : Hidden dimension for cross attention; required if has_xattn is True xa_n_heads : Number of attention heads used in cross attention; required if has_xattn is True + xa_d_head : Head dimension for cross attention. if None, defaults to d_model // xa_n_heads in Attention class. is_causal : Whether to make attention and the convolution feedforward networks causal. apply_norm_to_cond : Whether to apply normalization to conditioning tensor; conditioning tensor being the input to the memory part of cross-attention. 
@@ -618,6 +624,7 @@ def __init__( has_xattn=has_xattn, xa_d_memory=xa_d_memory, xa_n_heads=xa_n_heads, + xa_d_head=xa_d_head, is_causal=is_causal, apply_norm_to_cond=apply_norm_to_cond, max_length_causal_mask=max_length_causal_mask, From 024a168397bd117909b19bcd7d8765756099ee7d Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 10 Jun 2025 08:43:25 -0700 Subject: [PATCH 045/113] [magpietts][wandb] fixed unexpected panel displaying the val_loss in wandb UI. Previously val_loss was displayed in the pannel 'Charts'. Now it is displayed in the pannel 'val', the same group as other val metrics. Also fixed typos. (#13848) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- nemo/collections/tts/models/magpietts.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 02ca887ed16e..b5bf878d3a61 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -113,7 +113,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False) # TODO @xueyang: both tokenizers are only used to get some token ids. We - # should kill them to save a small mount of mem resources since dataloader will initialize them + # should kill them to save a small amount of mem resources since dataloader will initialize them # again after the worker processes are spawned. self.tokenizer, self.text_conditioning_tokenizer = setup_tokenizers( all_tokenizers_config=cfg.text_tokenizers, @@ -1661,7 +1661,7 @@ def infer_batch( ) finished_items = {k: v for k, v in finished_texts_counter.items() if v >= 20} # Items that have been close to the end for atleast 20 timesteps - unifinished_items = {k: v for k, v in unfinished_texts.items() if v} + unfinished_items = {k: v for k, v in unfinished_texts.items() if v} all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) if use_local_transformer_for_inference: @@ -1671,7 +1671,7 @@ def infer_batch( dec_output=dec_out[:,-1,:], temperature=temperature, topk=topk, - unfinished_items=unifinished_items, + unfinished_items=unfinished_items, finished_items=finished_items, use_cfg=use_cfg, cfg_scale=cfg_scale @@ -1681,7 +1681,7 @@ def infer_batch( dec_output=dec_out[:,-1,:], temperature=temperature, topk=topk, - unfinished_items=unifinished_items, + unfinished_items=unfinished_items, finished_items=finished_items, use_cfg=use_cfg, cfg_scale=cfg_scale, @@ -1693,8 +1693,8 @@ def infer_batch( all_codes_next_argmax = audio_codes_next else: # Parallel sampling from all codebooks - audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) - all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unifinished_items, finished_items=finished_items) # (B, num_codebooks) + audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks) + all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks) for item_idx in 
range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: @@ -1801,7 +1801,10 @@ def on_validation_epoch_end(self): val_codebook_loss = collect("val_codebook_loss") val_alignment_loss = collect("val_alignment_loss") val_aligner_encoder_loss = collect("val_aligner_encoder_loss") - self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) + # log val_loss in the same group as the other val metrics. + self.log("val/loss", val_loss, prog_bar=True, sync_dist=True) + # ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs. + self.log("val_loss", val_loss, prog_bar=False, sync_dist=True, on_step=False, on_epoch=True, logger=False, enable_graph=False) self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) From 926f1c33ff66ba28e9bd67763db3688db3b744c4 Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 11 Jun 2025 10:41:17 -0400 Subject: [PATCH 046/113] Add more CI tests for Magpie branch (#13831) * refactor 1: typos, yaml updates, code changes Signed-off-by: Jason * add new test Signed-off-by: Jason * add the new tests Signed-off-by: Jason * pull in latest container build from main; hopefully it works without merging the rest of main Signed-off-by: Jason * update repo location Signed-off-by: Jason * try removing wait-in-queue from needs Signed-off-by: Jason * fix tests Signed-off-by: Jason * Update nemo/collections/tts/data/text_to_speech_dataset.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * undo dataset change Signed-off-by: Jason * remove infer and eval tests for now Signed-off-by: Jason * fix(huggingface-hub): allow offline mode (#11901) * fix(huggingface-hub): allow offline mode Allow to reuse already cached models without the need for a connection to the internet or to the HuggingFace model hub. 
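A minimal sketch of the cache-first lookup this fix relies on (the helper name `resolve_nemo_checkpoint` is hypothetical; only the huggingface_hub calls are taken from the patch):

    from typing import Optional
    from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache

    def resolve_nemo_checkpoint(model_name: str, refresh_cache: bool = False) -> Optional[str]:
        # e.g. model_name = "org/some_model" -> cached file "some_model.nemo"
        filename = model_name.split("/")[-1] + ".nemo"
        if not refresh_cache:
            # Returns a local path if the file is already cached, _CACHED_NO_EXIST if a
            # previous lookup recorded that the file does not exist on the Hub, or None
            # if nothing has been cached yet.
            path = try_to_load_from_cache(repo_id=model_name, filename=filename)
            if path is not None and path is not _CACHED_NO_EXIST:
                return path  # reuse the cache; no network access required
        return None  # caller falls back to hf_hub_download / snapshot_download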
Signed-off-by: Strobel Maximilian (IFAG PSS SIS SCE ACM) * Apply isort and black reformatting Signed-off-by: maxstrobel * flake8 common.py Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> * flake common.py Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> --------- Signed-off-by: Strobel Maximilian (IFAG PSS SIS SCE ACM) Signed-off-by: maxstrobel Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: maxstrobel Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: Nithin Rao * fix L2 tests to use new json, cleanup disk usage, and add target checks Signed-off-by: Jason * undo comment Signed-off-by: Jason --------- Signed-off-by: Jason Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Strobel Maximilian (IFAG PSS SIS SCE ACM) Signed-off-by: maxstrobel Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Max Strobel Co-authored-by: maxstrobel Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: Nithin Rao --- .github/workflows/cicd-main.yml | 38 +++++++++++++++++-- examples/tts/conf/magpietts/magpietts_en.yaml | 1 - nemo/core/classes/common.py | 19 ++++++---- scripts/magpietts/evalset_config.py | 5 +++ scripts/magpietts/infer_and_evaluate.py | 36 ++++++++++++++---- ...Fast_dev_runs_Magpietts_DecoderContext.sh} | 4 +- ...TS_Fast_dev_runs_Magpietts_MultiEncoder.sh | 33 ++++++++++++++++ ...TS_InferEvaluate_Magpietts_SeenSpeakers.sh | 29 ++++++++++++++ ...L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh | 29 ++++++++++++++ 9 files changed, 171 insertions(+), 23 deletions(-) rename tests/functional_tests/{L2_TTS_Fast_dev_runs_Magpietts_config1.sh => L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh} (96%) create mode 100644 tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh create mode 100644 tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh create mode 100644 tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index fdf0b0c94876..120a15e12bac 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -1162,14 +1162,41 @@ jobs: # SCRIPT: |- # RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_1_Hifigan.sh - L2_TTS_Fast_dev_runs_Magpietts_config1: + L2_TTS_Fast_dev_runs_Magpietts_DecoderContext: needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_Magpietts_config1') + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_Magpietts_DecoderContext') with: RUNNER: self-hosted-azure SCRIPT: |- - RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh + RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh + + L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 
'L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder') + with: + RUNNER: self-hosted-azure + SCRIPT: |- + RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh + + L2_TTS_InferEvaluate_Magpietts_ZeroShot: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_InferEvaluate_Magpietts_ZeroShot') + with: + RUNNER: self-hosted-azure + SCRIPT: |- + RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh + + L2_TTS_InferEvaluate_Magpietts_SeenSpeakers: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_InferEvaluate_Magpietts_SeenSpeakers') + with: + RUNNER: self-hosted-azure + SCRIPT: |- + RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh # # L2: NeRF # # L2_NeRF_DreamFusion: @@ -1816,7 +1843,10 @@ jobs: - L0_Unit_Tests_CPU_Lightning - L0_Unit_Tests_CPU_Others - - L2_TTS_Fast_dev_runs_Magpietts_config1 + - L2_TTS_Fast_dev_runs_Magpietts_DecoderContext + - L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder + - L2_TTS_InferEvaluate_Magpietts_ZeroShot + - L2_TTS_InferEvaluate_Magpietts_SeenSpeakers # - ASR_dev_run_Speech_to_Text # - ASR_dev_run_Speech_to_Text_WPE_CitriNet diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index 25dd489a2e2e..cb5842cdf834 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -29,7 +29,6 @@ model: codecmodel_path: ??? max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} cfg_unconditional_prob: 0.1 # Alignment encoder parameters, to binarize the prior diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 6b34b3067765..78cb7dfc76aa 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -14,6 +14,7 @@ """Interfaces common to all Neural Modules and Models.""" +from __future__ import annotations import copy import hashlib import inspect @@ -26,14 +27,14 @@ from enum import Enum from functools import total_ordering from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import hydra import torch import wrapt -from huggingface_hub import HfApi +from huggingface_hub import _CACHED_NO_EXIST, HfApi from huggingface_hub import get_token as get_hf_token -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import hf_hub_download, snapshot_download, try_to_load_from_cache from omegaconf import DictConfig, OmegaConf import nemo @@ -853,6 +854,12 @@ def _get_hf_hub_pretrained_model_info(cls, model_name: str, refresh_cache: bool # Resolve the model name without origin for filename resolved_model_filename = model_name.split("/")[-1] + '.nemo' + # Try to take from cache first - if not fallback to options below + if not refresh_cache: + path = try_to_load_from_cache(repo_id=model_name, filename=resolved_model_filename) + if path is not None and path is not _CACHED_NO_EXIST: + return cls, path + # Check if api token exists, use if it does hf_token = get_hf_token() @@ -918,11 +925,7 @@ def _get_hf_hub_pretrained_model_info(cls, model_name: str, refresh_cache: bool token=hf_token, ) - # Cannot 
pre-resolve the specific class without double instantiation (first for config, second for model params) - # Default to current class, and perform basic class path resolution (handled via restore_from() + target class) - class_ = cls - - return class_, path + return cls, path def generate_model_card( self, type: str = "hf", template: str = None, template_kwargs: Optional[Dict[str, str]] = None diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 05b9fd66da4e..c7ad23724aaa 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -282,4 +282,9 @@ 'audio_dir' : '/mnt/drive1/data/LibriTTS/', 'feature_dir' : None, }, + 'an4_val_ci': { + 'manifest_path' : '/home/TestData/an4_dataset/an4_val_context_v1.json', + 'audio_dir' : '/', + 'feature_dir' : None, + }, } \ No newline at end of file diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 52b55afc8192..59ae5c0e4f09 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -117,7 +117,8 @@ def run_inference( confidence_level=0.95, use_local_transformer=False, maskgit_n_steps=3, - legacy_codebooks=False + legacy_codebooks=False, + clean_up_disk=False, ): # Load model if hparams_file is not None: @@ -256,7 +257,7 @@ def run_inference( use_local_transformer_for_inference=use_local_transformer, maskgit_n_steps=maskgit_n_steps ) - + all_rtf_metrics.append(rtf_metrics) et = time.time() print(f"Time taken for inference: {et-st}", predicted_audio.size()) @@ -331,7 +332,16 @@ def run_inference( with open(all_experiment_csv_with_ci, "a") as f: f.write(f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']},{metrics_mean_ci['frechet_codec_distance']}\n") print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") - + + measurements = [m['ssim_pred_context_avg'] for m in metrics_n_repeated] + ssim = np.mean(measurements) + measurements = [m['cer_cumulative'] for m in metrics_n_repeated] + cer = np.mean(measurements) + + if clean_up_disk: + shutil.rmtree(out_dir) + return cer, ssim + def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') @@ -364,6 +374,9 @@ def main(): parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) parser.add_argument('--legacy_codebooks', action='store_true') + parser.add_argument('--clean_up_disk', action='store_true') + parser.add_argument('--cer_target', type=float, default=None) + parser.add_argument('--ssim_target', type=float, default=None) args = parser.parse_args() estimate_alignment_from_layers = None @@ -380,7 +393,7 @@ def main(): print("Running inference for checkpoint files: ", checkpoint_files) assert len(hparam_files) == len(checkpoint_files), "Number of hparams files and checkpoint files should be the same." 
for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): - run_inference( + cer, ssim = run_inference( hparams_file=hparams_file, checkpoint_file=checkpoint_file, nemo_file=None, @@ -404,13 +417,14 @@ def main(): confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, - legacy_codebooks=args.legacy_codebooks + legacy_codebooks=args.legacy_codebooks, + clean_up_disk=args.clean_up_disk ) return elif (args.nemo_file is not None): nemo_file = args.nemo_file print("Running inference for nemo file: ", nemo_file) - run_inference( + cer, ssim = run_inference( hparams_file=None, checkpoint_file=None, nemo_file=nemo_file, @@ -434,7 +448,8 @@ def main(): confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, - legacy_codebooks=args.legacy_codebooks + legacy_codebooks=args.legacy_codebooks, + clean_up_disk=args.clean_up_disk ) else: BASE_EXP_DIR = args.base_exp_dir @@ -497,8 +512,13 @@ def main(): confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, - legacy_codebooks=args.legacy_codebooks + legacy_codebooks=args.legacy_codebooks, + clean_up_disk=args.clean_up_disk ) + if cer > float(args.cer_target): + raise ValueError() + if ssim < float(args.ssim_target): + raise ValueError() if __name__ == '__main__': diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh similarity index 96% rename from tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh rename to tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh index a6aaa2013350..b1c7967fa273 100644 --- a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_config1.sh +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh @@ -13,11 +13,11 @@ # limitations under the License. coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ --config-name magpietts_dc_en \ - +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train.json" \ + +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train_context_v1.json" \ +train_ds_meta.an4.audio_dir="/" \ +train_ds_meta.an4.tokenizer_names="[english_phoneme]" \ +train_ds_meta.an4.feature_dir=null \ - +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val.json" \ + +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \ +val_ds_meta.an4.audio_dir="/" \ +val_ds_meta.an4.tokenizer_names="[english_phoneme]" \ +val_ds_meta.an4.feature_dir=null \ diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh new file mode 100644 index 000000000000..861ad3fdb92d --- /dev/null +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh @@ -0,0 +1,33 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ + --config-name magpietts_en \ + +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train_context_v1.json" \ + +train_ds_meta.an4.audio_dir="/" \ + +train_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +train_ds_meta.an4.feature_dir=null \ + +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \ + +val_ds_meta.an4.audio_dir="/" \ + +val_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +val_ds_meta.an4.feature_dir=null \ + max_epochs=1 \ + batch_size=4 \ + model.codecmodel_path="/home/TestData/tts/21fps_causal_codecmodel.nemo" \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.strategy=auto \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.num_workers=0 \ + ~trainer.check_val_every_n_epoch diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh new file mode 100644 index 000000000000..6a91252decfd --- /dev/null +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh @@ -0,0 +1,29 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo scripts/magpietts/infer_and_evaluate.py \ + --codecmodel_path /home/TestData/tts/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo \ + --datasets an4_val_ci \ + --out_dir ./mp_ss_0 \ + --batch_size 4 \ + --use_cfg \ + --cfg_scale 2.5 \ + --num_repeats 1 \ + --temperature 0.6 \ + --hparams_files /home/TestData/tts/2506_SeenSpeaker/hparams.yaml \ + --checkpoint_files /home/TestData/tts/2506_SeenSpeaker/T5TTS--val_loss=0.3125-epoch=8.ckpt \ + --legacy_codebooks \ + --apply_attention_prior \ + --clean_up_disk \ + --cer_target 0.3 \ + --ssim_target 0.5 diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh new file mode 100644 index 000000000000..098b04e7a637 --- /dev/null +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh @@ -0,0 +1,29 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo scripts/magpietts/infer_and_evaluate.py \ + --codecmodel_path /home/TestData/tts/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo \ + --datasets an4_val_ci \ + --out_dir ./mp_zs_0 \ + --batch_size 4 \ + --use_cfg \ + --cfg_scale 2.5 \ + --num_repeats 1 \ + --temperature 0.6 \ + --hparams_files /home/TestData/tts/2506_ZeroShot/lrhm_short_yt_prioralways_alignement_0.002_priorscale_0.1.yaml \ + --checkpoint_files /home/TestData/tts/2506_ZeroShot/dpo-T5TTS--val_loss=0.4513-epoch=3.ckpt \ + --legacy_codebooks \ + --apply_attention_prior \ + --clean_up_disk \ + --cer_target 0.1 \ + --ssim_target 0.7 From 2bba9765045d380513d668d6c7c3ca37db91ca55 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Thu, 12 Jun 2025 16:45:05 -0700 Subject: [PATCH 047/113] FCD Metric bugfix: handle empty codes update (#13906) We've observed instances where the model generates zero-length outputs which can then be fed into the FCD calculation. Update the metric to handle this case without crashing and add a unit test. --- nemo/collections/tts/modules/fcd_metric.py | 7 ++++--- tests/collections/tts/modules/test_fcd_metric.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/modules/fcd_metric.py b/nemo/collections/tts/modules/fcd_metric.py index 5d703c6ddfbd..9c2c4d31b540 100644 --- a/nemo/collections/tts/modules/fcd_metric.py +++ b/nemo/collections/tts/modules/fcd_metric.py @@ -23,10 +23,7 @@ be useful to explore). 
""" -import warnings - import torch -import torch.nn.functional as F from torch import nn, Tensor from torchmetrics import Metric import numpy as np @@ -171,6 +168,10 @@ def update(self, codes: Tensor, codes_len: Tensor, is_real: bool): """ assert codes.ndim == 3 + if codes.numel() == 0: + logging.warning(f"\nFCD metric received an empty batch of codes - skipping update\n") + return + # Dequantize the codes to a continuous representation embeddings = self.model.codes_to_embedding( codes, codes_len diff --git a/tests/collections/tts/modules/test_fcd_metric.py b/tests/collections/tts/modules/test_fcd_metric.py index 6c2807716b05..bba65cf7b821 100644 --- a/tests/collections/tts/modules/test_fcd_metric.py +++ b/tests/collections/tts/modules/test_fcd_metric.py @@ -29,7 +29,7 @@ def codec(self, device, scope="session"): return AudioCodecModel.from_pretrained("nvidia/low-frame-rate-speech-codec-22khz").to(device) @pytest.fixture - def metric(self, codec, device): + def metric(self, codec, device): codec_feature_dim = codec.vector_quantizer.codebook_dim return FrechetCodecDistance(codec=codec, feature_dim=codec_feature_dim).to(device) @@ -101,3 +101,12 @@ def test_update_from_audio_file(self, metric): fcd = metric.compute() assert isinstance(fcd, torch.Tensor) assert fcd > 0, f"FCD value is {fcd} but should be positive given that we tested different audio files" + + @pytest.mark.unit + def test_empty_codes_update(self, metric, device): + """Test that the FCD metric doesn't crash when provided with empty codes.""" + B, C, T = 1, 0, 100 + codes = torch.ones(B, C, T, device=device) + codes_len = T * torch.ones(B, device=device) + # if it crashes PyTest will report it + metric.update(codes, codes_len, is_real=True) From 7be8621087a149fd96895a112e30d2362d337082 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Mon, 16 Jun 2025 13:56:48 -0700 Subject: [PATCH 048/113] [magpietts][lhotse_v2] make model training recipe adapt to the latest v2 datasets. (#13879) * [magpietts][lhotse_v2] make model training recipe adapt to the latest changes of lhotse v2 datasets. - change the codes names in cuts into target_codes and context_codes. - remove codec_model_name in yaml config because we've added the codec name in input_cfg.yaml file. - fixed bugs that spec_len has extended with bos and eos resulting in unexpected 2 more tokens. Prioritize normalized_text over text if available. - added cuts sampling filters in terms of max cer and min context speaker similarity. - add sample_rate overrides the same as codec model's. Without overrides, the audio will be resampled into 16k by default. - bugfix: ensure start and duration params equal to None in TemporalArray.load(start, duration) since the audio codes were extracted based on segmented audio. - downlevel to logging.warning since they were verbosed for every step. - explicitly call wandb.finish() after trainer.fit or trainer.test complete to avoide hangs when debugging with num_worker=0. - extract numerical value from tensors. 
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * bugfix: replace cfg.model with model_cfg Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * bugfix: unit tests Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../magpietts/magpietts_lhotse_dc_en.yaml | 5 +- examples/tts/magpietts.py | 40 ++++++++---- .../common/data/lhotse/dataloader.py | 11 ++++ .../common/data/lhotse/sampling.py | 28 ++++++++ .../tts/data/text_to_speech_dataset_lhotse.py | 65 ++++++++----------- nemo/collections/tts/models/magpietts.py | 35 ++++++---- 6 files changed, 124 insertions(+), 60 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index b8f19208149a..bac847091df4 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -8,7 +8,6 @@ model: use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. context_duration_min: 5.0 context_duration_max: 5.0 - codec_model_name: "21fpsCausalDecoder" load_cached_codes_if_available: true prior_scaling_factor: 0.5 prior_end_step: 12_000 @@ -64,6 +63,8 @@ model: dataset: min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 batch_duration : ??? # in seconds. Adjust based on your GPU memory. quadratic_duration: ${quadratic_duration} use_bucketing: true @@ -90,6 +91,8 @@ model: dataset: min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. quadratic_duration: ${quadratic_duration} use_bucketing: false diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index a40438f8cc28..5681600f8617 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -47,31 +47,49 @@ def main(cfg): trainer.callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step', log_weight_decay=True)) exp_manager(trainer, cfg.get("exp_manager", None)) - if cfg.get('mode', 'train') == 'train': + mode = cfg.get('mode', 'train') + if mode == 'train': model = MagpieTTSModel(cfg=cfg.model, trainer=trainer) - elif cfg.get('mode', 'train') == 'dpo_train': + elif mode == 'dpo_train': model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt model = MagpieTTSModelOfflinePO(cfg=model_cfg, trainer=trainer) - elif cfg.get('mode', 'train') == 'onlinepo_train': + elif mode == 'onlinepo_train': model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt model = MagpieTTSModelOnlinePO(cfg=model_cfg, trainer=trainer) - elif cfg.get('mode', 'train') == 'test': + elif mode == 'test': model = MagpieTTSModelOfflinePODataGen(cfg=cfg.model, trainer=trainer) else: - raise NotImplementedError(f"Only train, dpo_train and test modes are supported. Got {cfg.mode}") + raise NotImplementedError( + f"Only train, dpo_train, onlinepo_train and test modes are supported. 
Got {mode}" + ) model.maybe_init_from_pretrained_checkpoint(cfg=cfg) - if cfg.get('mode', 'train') in ['train', 'dpo_train', 'onlinepo_train']: - trainer.fit(model) - elif cfg.get('mode', 'train') == 'test': - trainer.test(model) - else: - raise NotImplementedError(f"Only train and test modes are supported. Got {cfg.mode}") + try: + if mode in ['train', 'dpo_train', 'onlinepo_train']: + logging.info("Starting training...") + trainer.fit(model) + elif mode == 'test': + logging.info("Starting testing...") + trainer.test(model) + else: + raise NotImplementedError(f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}") + logging.info("Training/testing completed successfully.") + finally: + # Ensure WandB completes all uploads before Python thread shutdown + # Critical when num_workers=0 during debugging - the main process can become + # overwhelmed and fail to properly coordinate with WandB's background threads + try: + import wandb + if wandb.run is not None: + logging.info("Finishing WandB run to prevent threading shutdown hang...") + wandb.finish() + except Exception as e: + logging.warning(f"Error finishing WandB: {e}") if __name__ == '__main__': diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index fb3c97bc9bdb..20beff3c3817 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -45,6 +45,8 @@ ) from nemo.collections.common.data.lhotse.sampling import ( BucketingFilter, + CERFilter, + ContextSpeakerSimilarityFilter, DurationFilter, FixedBucketBatchSizeConstraint2D, MultimodalFixedBucketBatchSizeConstraint2D, @@ -130,6 +132,10 @@ class LhotseDataLoadingConfig: min_tpt: int = -1 # allowed tokens per token (text-only) max_tpt: Any = float("inf") # float | list[float] + # 2.3 Filters on CER and/or cosine speaker similarity of the context audio serving for TTS use cases. + max_cer: float | None = float("inf") + min_context_speaker_similarity: float | None = -1 + # 3. Supported existing NeMo options. shuffle: bool = False sample_rate: int = 16000 @@ -540,6 +546,11 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No TokenCountFilter(config.min_tokens, config.max_tokens, measure_total_length=config.measure_total_length) ) + # CER filtering, same as native NeMo dataloaders. + cuts = cuts.filter(CERFilter(config.max_cer)) + # Context speaker similarity filtering, same as native NeMo dataloaders. + cuts = cuts.filter(ContextSpeakerSimilarityFilter(config.min_context_speaker_similarity)) + if tokenizer is not None and config.pretokenize: cuts = cuts.filter(TokenPerSecondFilter(config.min_tps, config.max_tps)) cuts = cuts.filter(TokenPerTokenFilter(config.min_tpt, config.max_tpt)) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index de2ad4074e1f..48944fdb89b2 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -262,6 +262,34 @@ def __call__(self, example) -> bool: else: return True # does not apply to text etc. +class CERFilter: + """ + Callable, returns ``True`` if a cut's CER is less than max_cer and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. 
+ """ + + def __init__(self, max_cer: float | None) -> None: + self.max_cer = ifnone(max_cer, float("inf")) + + def __call__(self, example) -> bool: + if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("cer"): + return example.supervisions[0].cer <= self.max_cer + else: + return True + +class ContextSpeakerSimilarityFilter: + """ + Callable, returns ``True`` if a cut's context speaker similarity is greater than min_context_speaker_similarity and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + def __init__(self, min_context_speaker_similarity: float | None) -> None: + self.min_context_speaker_similarity = ifnone(min_context_speaker_similarity, -1) + + def __call__(self, example) -> bool: + if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("context_speaker_similarity"): + return example.supervisions[0].context_speaker_similarity >= self.min_context_speaker_similarity + else: + return True class TokenCountFilter: """ diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index a75dda38cd02..b43f8ed032e7 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -32,8 +32,6 @@ ) from nemo.utils import logging -SUPPORTED_CODEC_MODEL_NAMES = ["21fpsCausalDecoder", "12fpsCausalDecoder"] - def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mode='train'): # Being used in both model and worker_init_fn, so it is defined here @@ -97,10 +95,6 @@ class MagpieTTSLhotseDataset(torch.utils.data.Dataset): codec_model_samples_per_frame (int): The total downsampling factor of the audio codec model used to generate codes. Used for padding audio and calculating number of codec frames. - codec_model_name (str): Name identifier for the codec model, used to - determine the field name for loading cached codes (e.g., "codes_21fpsCausalDecoder"). - Defaults to "21fpsCausalDecoder". Supported values defined in - `SUPPORTED_CODEC_MODEL_NAMES`. audio_bos_id (int): Token ID representing the beginning-of-sequence (BOS) for target audio codes. audio_eos_id (int): Token ID representing the end-of-sequence (EOS) for target @@ -142,7 +136,6 @@ def __init__( sample_rate: int, volume_norm: bool = True, codec_model_samples_per_frame: int = None, - codec_model_name: str = "21fpsCausalDecoder", audio_bos_id: int = None, audio_eos_id: int = None, context_audio_bos_id: int = None, @@ -166,9 +159,6 @@ def __init__( self.context_audio_bos_id = context_audio_bos_id self.context_audio_eos_id = context_audio_eos_id - if codec_model_name not in SUPPORTED_CODEC_MODEL_NAMES: - raise ValueError(f"Invalid `codec_model_name`: {codec_model_name}.") - self.codec_model_name = codec_model_name self.codec_model_samples_per_frame = codec_model_samples_per_frame self.num_audio_codebooks = num_audio_codebooks @@ -230,9 +220,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: context_text_tokens_len_list = [] context_has_text_context_list = [] reward_list = [] - raw_text_list = [] - target_codes_field = f"codes_{self.codec_model_name}" - context_codes_field = f"context_codes_{self.codec_model_name}" + raw_text_list = [] # raw text here is the string of normalized text or text stored in the supervision segment. Used to distinguish from text tokens. 
for cut in cuts: speaker = cut.supervisions[0].speaker if not check_speaker_format(speaker): @@ -241,16 +229,19 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: dataset_name_list.append(dataset_name) # target audio or target codes - if self.load_cached_codes_if_available and cut.has_custom(target_codes_field): - audio_codes = torch.from_numpy(cut.load_custom(target_codes_field)).long() # (8, T) + if self.load_cached_codes_if_available and cut.has_custom("target_codes"): + # TODO @xueyang: applying Tensor.long(), i.e. torch.int64, is not necessary. + + # Note that we have segmented the audio according to offset and duration so that the audio codes should + # not specify start and duration again when calling TemporalArray.load(start, duration). Ensure start + # and duration are None to the load function. + audio_codes = torch.from_numpy(cut.target_codes.load()).long() # (C, T) + spec_len = audio_codes.shape[1] + 1 # +1 for EOS audio_bos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_bos_id, dtype=audio_codes.dtype) audio_eos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_eos_id, dtype=audio_codes.dtype) audio_codes = torch.cat([audio_bos_tensor, audio_codes, audio_eos_tensor], dim=1) audio_codes_len = audio_codes.shape[1] - spec_len = audio_codes.shape[1] + 1 # +1 for EOS - audio_codes_list.append( - audio_codes.T - ) # transpose to (T, 8) in order to use collate_matrices to process batch. + audio_codes_list.append(audio_codes.T) # transpose to (T, C) to use collate_matrices to process batch. audio_codes_len_list.append(audio_codes_len) else: # Only load audio if codes are not available @@ -270,10 +261,13 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: audio_len_list.append(audio_len) # context audio or context codes - if self.load_cached_codes_if_available and cut.has_custom(context_codes_field): - # TODO @xueyang: dev branch applied Tensor.long(), i.e. torch.int64 which is not necessary. - # load audios and text - context_audio_codes = torch.from_numpy(cut.load_custom(context_codes_field)).long() # (8, T) + if self.load_cached_codes_if_available and cut.has_custom("context_codes"): + # TODO @xueyang: applying Tensor.long(), i.e. torch.int64, is not necessary. + + # Note that we have segmented the audio according to offset and duration so that the audio codes should + # not specify start and duration again when calling TemporalArray.load(start, duration). Ensure start + # and duration are None to the load function. + context_audio_codes = torch.from_numpy(cut.context_codes.load()).long() # (8, T) # Sample random duration between self.context_duration_min and self.context_duration_max _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) _num_frames_to_slice = int( @@ -323,8 +317,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: context_audio_list.append(context_audio) context_audio_len_list.append(context_audio_len) else: - # We always want to have context_audio_codes if available for multi-encoder model. These are ignored - # for singlencoder model. + # We always want to have context_audio_codes if available for multi-encoder model. These are ignored for single-encoder model. # If context audio is not available, just use a dummy context_audio_codes # (Will be used in text context scenario) # TODO @xueyang: verified that this block should cover below 3 conditions which were handled well. 
@@ -343,9 +336,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: ) context_audio_codes = torch.cat([context_bos_tensor, context_eos_tensor], dim=1) context_audio_codes_len = context_audio_codes.shape[1] - context_audio_codes_list.append( - context_audio_codes.T - ) # transpose to (T, 8) in order to use collate_matrices to process batch. + context_audio_codes_list.append(context_audio_codes.T) # transpose to (T, C) to use collate_matrices to process batch. context_audio_codes_len_list.append(context_audio_codes_len) else: # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes @@ -377,9 +368,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: if self.use_text_conditioning_tokenizer: if cut.supervisions[0].has_custom("context_text"): - context_text_tokens = self.text_conditioning_tokenizer(cut.supervisions[0].context_text)[ - 'input_ids' - ] + context_text_tokens = self.text_conditioning_tokenizer(cut.supervisions[0].context_text)['input_ids'] has_text_context = True else: context_text_tokens = self.text_conditioning_tokenizer("[NO TEXT CONTEXT]")['input_ids'] @@ -394,7 +383,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: else: # TODO @xueyang: It seems counter intuition if trimming the text context tokens to the required # context length. For example, the context_tokens after trimming may correspond to the partial - # context_text like "Speaker and Emotion: | Language:en Dataset(trimmed :Riva Speaker:Rodney_DROP |)" + # context_text like "Speaker and Emotion: | Language:en Dataset" where the following string is trimmed: ":Riva Speaker:Rodney_DROP |". context_text_tokens = context_text_tokens[:_required_len] context_text_tokens = torch.tensor(context_text_tokens, dtype=torch.int32) context_text_tokens_len = context_text_tokens.shape[0] @@ -403,15 +392,18 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: context_has_text_context_list.append(has_text_context) # tokenize transcript - # TODO @xueyang: temporally apply raw text. will check to change if normalized text is available. - raw_text = cut.supervisions[0].text - raw_text_list.append(raw_text) + # there may exist "normalized_text" in the suprvisionsegement. Prioritize it over "text" if available. 
+ if cut.supervisions[0].has_custom("normalized_text"): + text_str = cut.supervisions[0].normalized_text + else: + text_str = cut.supervisions[0].text + raw_text_list.append(text_str) if cut.has_custom("tokenizer_names"): # Pick a random tokenizer from the list of tokenizers tokenizer_name = random.choice(cut.tokenizer_names) else: tokenizer_name = "english_phoneme" # Default to english phoneme tokenizer - tokens = self.text_tokenizer.encode(text=raw_text, tokenizer_name=tokenizer_name) + tokens = self.text_tokenizer.encode(text=text_str, tokenizer_name=tokenizer_name) tokens = tokens + [self.eos_id] # Not adding BOS id tokens = torch.tensor(tokens, dtype=torch.int32) text_len = tokens.shape[0] @@ -419,7 +411,6 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: token_len_list.append(text_len) if self.include_align_prior: - # align_prior = self.beta_binomial_interpolator(spec_len, text_len) align_prior = beta_binomial_prior_distribution( phoneme_count=text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor ) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index b5bf878d3a61..9b650ee4207b 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -21,7 +21,7 @@ from hydra.utils import instantiate from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger -from omegaconf import DictConfig, open_dict +from omegaconf import DictConfig, OmegaConf, open_dict from torch import nn from torch.utils.data import get_worker_info @@ -1020,13 +1020,13 @@ def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_ha def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens): # aligner_attn_soft B, 1, audio_timesteps, text_timesteps if self.binarize_attn_method == 'nemo_binarize': - logging.info("Binarizing attention using nemo_binarize") + logging.debug("Binarizing attention using nemo_binarize") binarize_repeat_audio_factor = self.binarize_repeat_audio_factor aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave(binarize_repeat_audio_factor, dim=2) # B, 1, 2*audio_timesteps, text_timesteps aligner_attn_hard = binarize_attention_parallel(aligner_attn_soft_repeated, text_lens, audio_lens*binarize_repeat_audio_factor).squeeze(1) # B, 2*audio_timesteps, text_timesteps aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps elif self.binarize_attn_method == 'argmax': - logging.info("Binarizing attention using argmax") + logging.debug("Binarizing attention using argmax") aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1) aligner_attn_hard = torch.nn.functional.one_hot(aligner_attn_hard, num_classes=aligner_attn_soft.size(-1)).float() else: @@ -1265,7 +1265,7 @@ def training_step(self, batch, batch_idx): batch_info_dict = { "train/batch_size": batch_size, "train/text_token_max_len": text_token_max_len, - "train/text_token_total_num_in_batch": text_token_total_num, + "train/text_token_total_num_in_batch": text_token_total_num.item(), "train/text_token_pad_ratio_percent_in_batch": 100 * (1 - text_token_total_num / (batch_size * text_token_max_len)), } @@ -1274,7 +1274,7 @@ def training_step(self, batch, batch_idx): audio_codes_total_num = batch["audio_codes_lens"].sum() batch_info_dict.update({ "train/audio_codes_max_len": audio_codes_max_len, - "train/audio_codes_total_num_in_batch": audio_codes_total_num, + 
"train/audio_codes_total_num_in_batch": audio_codes_total_num.item(), "train/audio_codes_pad_ratio_percent_in_batch": 100 * (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)), }) else: @@ -1282,7 +1282,7 @@ def training_step(self, batch, batch_idx): audio_samples_total_num = batch["audio_lens"].sum() batch_info_dict.update({ "train/audio_samples_max_len": audio_samples_max_len, - "train/audio_samples_total_num_in_batch": audio_samples_total_num, + "train/audio_samples_total_num_in_batch": audio_samples_total_num.item(), "train/audio_samples_pad_ratio_percent_in_batch": 100 * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)), }) @@ -1846,7 +1846,6 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D sample_rate=self.sample_rate, volume_norm=dataset_cfg.volume_norm, codec_model_samples_per_frame=self.codec_model_samples_per_frame, - codec_model_name=self.cfg.codec_model_name, audio_bos_id=self.audio_bos_id, audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, @@ -1874,6 +1873,14 @@ def setup_training_data(self, dataset_cfg): if dataset_cfg.get("use_lhotse", False): # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also # cfg is a classifier-free guidance. + + # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. + if not isinstance(dataset_cfg, DictConfig): + dataset_cfg = OmegaConf.create(dataset_cfg) + OmegaConf.set_struct(dataset_cfg.dataset, False) + dataset_cfg.dataset.update({"sample_rate": self.sample_rate}) + OmegaConf.set_struct(dataset_cfg.dataset, True) + self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train') else: dataset = self.get_dataset(dataset_cfg, dataset_type='train') @@ -1898,6 +1905,12 @@ def setup_training_data(self, dataset_cfg): def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: if dataset_cfg.get("use_lhotse", False): + # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. + if not isinstance(dataset_cfg, DictConfig): + dataset_cfg = OmegaConf.create(dataset_cfg) + OmegaConf.set_struct(dataset_cfg.dataset, False) + dataset_cfg.dataset.update({"sample_rate": self.sample_rate}) + OmegaConf.set_struct(dataset_cfg.dataset, True) data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test') else: dataset = self.get_dataset(dataset_cfg, dataset_type='test') @@ -1920,11 +1933,11 @@ def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: ) return data_loader - def setup_validation_data(self, cfg): - self._validation_dl = self._setup_test_dataloader(cfg) + def setup_validation_data(self, dataset_cfg): + self._validation_dl = self._setup_test_dataloader(dataset_cfg) - def setup_test_data(self, cfg): - self._test_dl = self._setup_test_dataloader(cfg) + def setup_test_data(self, dataset_cfg): + self._test_dl = self._setup_test_dataloader(dataset_cfg) @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: From 58b809f090d1fc611dbdfb3403e79ce1747a0e0b Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Mon, 16 Jun 2025 15:28:05 -0700 Subject: [PATCH 049/113] [magpietts][lhotse_v2] Adding scripts of converting NeMo manifests to Lhotse Shars, and speedup improvements for codec model inference. 
(#13821) * [magpietts][lhotse_v2] added 3 scripts and changes for collections Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [magpietts][lhotse] updade README to show details how toto run scripts. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * replace torch.no_grad context manager with torch.inference_mode. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- nemo/collections/common/parts/utils.py | 6 +- scripts/magpietts/README_lhotse.md | 409 +++++--- .../magpietts/convert_nemo_to_lhotse_shar.py | 464 --------- .../create_lhotse_shar_from_nemo_manifest.py | 504 ++++++++++ .../extend_lhotse_shards_with_audio_codes.py | 912 ++++++++++++++++++ ...extend_nemo_manifest_with_context_audio.py | 873 +++++++++++++++++ 6 files changed, 2584 insertions(+), 584 deletions(-) delete mode 100644 scripts/magpietts/convert_nemo_to_lhotse_shar.py create mode 100644 scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py create mode 100644 scripts/magpietts/extend_lhotse_shards_with_audio_codes.py create mode 100644 scripts/magpietts/extend_nemo_manifest_with_context_audio.py diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py index c00de27c55bd..ad56b28e61d3 100644 --- a/nemo/collections/common/parts/utils.py +++ b/nemo/collections/common/parts/utils.py @@ -142,13 +142,13 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): batch_size, *_, max_lengths = tensor.shape if len(tensor.shape) == 2: - mask = torch.ones(batch_size, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = torch.ones(batch_size, max_lengths, dtype=lengths.dtype, device=lengths.device).cumsum(dim=-1) mask = mask <= einops.rearrange(lengths, 'B -> B 1') elif len(tensor.shape) == 3: - mask = torch.ones(batch_size, 1, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = torch.ones(batch_size, 1, max_lengths, dtype=lengths.dtype, device=lengths.device).cumsum(dim=-1) mask = mask <= einops.rearrange(lengths, 'B -> B 1 1') elif len(tensor.shape) == 4: - mask = torch.ones(batch_size, 1, 1, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = torch.ones(batch_size, 1, 1, max_lengths, dtype=lengths.dtype, device=lengths.device).cumsum(dim=-1) mask = mask <= einops.rearrange(lengths, 'B -> B 1 1 1') else: raise ValueError('Can only mask tensors of shape B x L, B x D x L and B x D1 x D2 x L') diff --git a/scripts/magpietts/README_lhotse.md b/scripts/magpietts/README_lhotse.md index f56eb1cdd4dd..9a8eb7735d1b 100644 --- a/scripts/magpietts/README_lhotse.md +++ b/scripts/magpietts/README_lhotse.md @@ -1,174 +1,349 @@ -This guidance describes general steps on converting NeMo datasets to Lhotse Shar datasets for training -and validating Magpie-TTS. +This guidance describes the new Lhotse Shar process for converting NeMo datasets to Lhotse Shar datasets, designed for +training and validating Magpie-TTS. This new version significantly reduces computation overhead by using rank-balanced +workloading and independent writing across parallel processes. It also separate the processes to CPU-only nodes and +GPU-only nodes accordingly. 
## Creating New Lhotse Shar Data -Step 1: reformatting `speaker` field in the NeMo manifest to pass the format check as the function defined, + +The process involves four main steps: +1. **Prepare Input Manifests (on CPU nodes):** Standardize the input NeMo manifests for each dataset. +2. **Extend Manifests with Context Audio (on GPU nodes):** Enhance the NeMo manifests by adding context audio information. +3. **Create Lhotse Shards (on CPU nodes):** Convert the extended NeMo manifests into Lhotse shards. +4. **Extend Shards with Audio Codes (on GPU nodes):** Process the Lhotse shards to extract and include audio codes (audio codec extraction). + +### Step 1: Prepare Input Manifests + +This first step runs on **CPU nodes** and is responsible for standardizing the input NeMo manifests for each dataset. This may involve consolidating multiple input files or reformatting entries. It's a preparatory step to ensure the manifest is in the correct format for the subsequent stages. + +*Note: The actual implementation for this step ([`prep_input_manifest.py`](https://gitlab-master.nvidia.com/xueyang/nemo-tts-artifacts-registry/-/blob/model_release_2505/model_release_2505/data_prep/hifitts2/prep_input_manifest_iad.py) in the internal scripts) is highly specific to the dataset and environment. Users should create their own script to prepare a standardized manifest file as input for Step 2.* + +A crucial part of this step is to ensure the `speaker` field in the NeMo manifest conforms to the required format: ```python def check_speaker_format(item: str): # enforce the format as example like "| Language:en Dataset:HiFiTTS Speaker:9136_other |". pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" return bool(re.match(pattern, item)) ``` - -Step 2: create Lhotse Shar dataset by running, + +#### Checkout the Outputs of `hifitts2/prep_input_manifest.py` ```bash -# codec -CODEC_MODEL_NAME="21fpsCausalDecoder" -CODEC_MODEL_PATH="/codecs/21fps_causal_codecmodel.nemo" -CODEC_FRAME_RATE=21.5 -SAMPLE_RATE=22050 -PAD_MULTIPLE=1024 +$ tree -L 1 -P '*.json|*.txt' hifitts2/nemo_manifest/ +hifitts2/nemo_manifest/ +├── hifitts2_all_splits.json +├── hifitts2_dev_seen.json +├── hifitts2_dev_unseen.json +├── hifitts2_test_seen.json +├── hifitts2_test_unseen.json +├── hifitts2_train.json # This is the standardized NeMo manifest used for the following steps. +├── manifest_empty_normalized_text_fields.json +├── manifest_librosa_error.json +├── manifest_mismatched_audio_duration.json +├── manifest_missing_asr_metrics.json +├── manifest_missing_audio_files.json +├── manifest_missing_required_fields.json +└── stats.txt # This helps to understand the success and failure records. +``` + +### Step 2: Extend NeMo Manifest with Context Audio + +This step runs on **GPU nodes**. It enhances the standardized NeMo manifest from Step 1 by adding context audio information. + +Improvements over older recipes include: +- Speaker embedding extraction is run on the fly, using `torch.matmul` to compute a similarity matrix. +- It recursively finds the next best context audio if the top candidate is unsuitable, preserving more data. +- It is scaling-friendly by pre-allocating a distinct subset of speaker records to each GPU rank for balanced workloads using a greedy bin-packing strategy. +- Manifests are written out in a buffered way to reduce I/O calls. 
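To make the greedy bin-packing bullet above concrete, here is a minimal, self-contained sketch of the idea: speakers are assigned, heaviest first, to whichever rank currently has the smallest load, so every GPU ends up with a similar number of records. The function name and speaker counts below are hypothetical illustrations, not code taken from `extend_nemo_manifest_with_context_audio.py`.

```python
# Illustrative sketch (hypothetical names/counts) of greedy bin-packing:
# assign each speaker's records to the currently least-loaded rank so that
# every GPU rank processes a similar total number of utterances.
from collections import defaultdict

def allocate_speakers_to_ranks(records_per_speaker: dict, world_size: int) -> dict:
    loads = [0] * world_size           # total records assigned to each rank so far
    assignment = defaultdict(list)     # rank -> list of speaker ids
    # Heaviest speakers first: the classic longest-processing-time heuristic.
    for speaker, count in sorted(records_per_speaker.items(), key=lambda kv: kv[1], reverse=True):
        rank = min(range(world_size), key=lambda r: loads[r])
        assignment[rank].append(speaker)
        loads[rank] += count
    return dict(assignment)

# Example: five speakers balanced across two ranks.
print(allocate_speakers_to_ranks({"spk_a": 120, "spk_b": 80, "spk_c": 75, "spk_d": 40, "spk_e": 10}, world_size=2))
# -> {0: ['spk_a', 'spk_d'], 1: ['spk_b', 'spk_c', 'spk_e']}  (loads: 160 vs 165)
```

Assigning the largest speakers first keeps the most loaded rank close to the optimal balance, which is what makes the per-rank pre-allocation scale well across GPUs.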
+ +#### Example command: +```bash +# General setup +CODE_DIR="/workspace/NeMo" +export PYTHONPATH="${CODE_DIR}:${PYTHONPATH}" +cd ${CODE_DIR} -# trainer +# Script parameters +INPUT_MANIFEST="/path/to/hifitts2/nemo_manifest/hifitts2_train.json" # From Step 1 +AUDIO_BASE_DIR="/path/to/audio/files" +OUTPUT_DIR="/path/to/hifitts2/nemo_manifest" +DATASET_NAME="hifitts2" # e.g., hifitts, libritts, etc. Used for speaker ID parsing. +CONTEXT_MIN_DURATION=3.0 +CONTEXT_MIN_SSIM=0.6 +BATCH_SIZE=256 +FLUSH_THRESHOLD_ITEMS=256 +NUM_WORKERS=8 DEVICES=-1 NUM_NODES=1 -BATCH_SIZE=48 -NUM_WORKERS=10 -SHARD_SIZE=4096 +WANDB_ENTITY="xyz" +WANDB_PROJECT="xyz" +WANDB_NAME="xyz" -# code -CODE_DIR="/workspace/NeMo" +echo "****** STEP 2: Extending NeMo Manifest with Context Audio ******" +python scripts/magpietts/extend_nemo_manifest_with_context_audio.py \ + --dataset-name ${DATASET_NAME} \ + --manifest ${INPUT_MANIFEST} \ + --audio-base-dir ${AUDIO_BASE_DIR} \ + --output-dir ${OUTPUT_DIR} \ + --flush-threshold-items ${FLUSH_THRESHOLD_ITEMS} \ + --context-min-duration ${CONTEXT_MIN_DURATION} \ + --context-min-ssim ${CONTEXT_MIN_SSIM} \ + --batch-size ${BATCH_SIZE} \ + --devices ${DEVICES} \ + --num-nodes ${NUM_NODES} \ + --num-workers ${NUM_WORKERS} \ + --wandb-entity ${WANDB_ENTITY} \ + --wandb-project ${WANDB_PROJECT} \ + --wandb-name ${WANDB_NAME} +``` -# NeMo manifest -MANIFEST_PATH="/manifests/hifitts_train_withContextAudioMinDur3MinSSIM0.6.json" +#### Checkout the Outputs +```bash +$ tree -L 1 hifitts2/nemo_manifest/extend_nemo_manifest_with_context_audio/ +hifitts2/nemo_manifest/extend_nemo_manifest_with_context_audio/ +├── hifitts2_train_rank0.json +├── hifitts2_train_rank1.json +├── hifitts2_train_rank2.json +├── hifitts2_train_rank3.json +├── hifitts2_train_rank4.json +├── hifitts2_train_rank5.json +├── hifitts2_train_rank6.json +├── hifitts2_train_rank7.json +└── hifitts2_train_withContextAudioMinDur3.0MinSSIM0.6.json # This is the NeMo manifest used for the following steps. +``` -# audio base dir -AUDIO_BASE_DIR="/audio/hi_fi_tts_v0" -# save dir for Shar -SAVE_DIR="/data_shar_train" +### Step 3: Create Lhotse Shards from NeMo Manifest -echo "*******STARTING********" +This step runs on **CPU nodes**. It converts the extended NeMo manifests (from Step 2) into Lhotse shards. + +Key features: +- Processes chunks of manifest entries, loads audio, and writes corresponding shard files for cuts, target audio, and context audio. +- Designed to be run in parallel worker processes. +- Loads and writes audio iteratively to save memory. 
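The key features above boil down to a chunk-per-shard pattern: the manifest is split into fixed-size chunks, each chunk index doubles as the shard index, and independent worker processes write their own shard files. The trimmed-down sketch below stubs out the worker body; the full implementation in `create_lhotse_shar_from_nemo_manifest.py` (included later in this patch) additionally builds the Lhotse cuts, streams audio entry by entry, and tracks per-chunk error counts.

```python
# Condensed sketch of the chunk-per-shard parallelization pattern used by
# create_lhotse_shar_from_nemo_manifest.py: each chunk index is also the
# shard index, so every worker writes its own output files independently.
import itertools
from concurrent.futures import ProcessPoolExecutor, as_completed

def chunked_iterator(iterable, chunk_size):
    it = iter(iterable)
    while chunk := tuple(itertools.islice(it, chunk_size)):
        yield chunk

def process_and_write_chunk(chunk_with_idx):
    chunk_idx, entries = chunk_with_idx
    # Stub for the real work: build cuts, load target/context audio one entry
    # at a time, and write cuts/recording shards using this chunk's index.
    return {"chunk": chunk_idx, "processed": len(entries)}

if __name__ == "__main__":
    manifest_entries = ({"audio_filepath": f"dummy_{i:03d}.wav"} for i in range(10))  # stand-in for lazily loaded manifest rows
    with ProcessPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(process_and_write_chunk, chunk)
                   for chunk in enumerate(chunked_iterator(manifest_entries, 4))]
        for future in as_completed(futures):
            print(future.result())
```

Because each chunk maps to exactly one shard (the real script passes the chunk index as `shard_offset` to its writers), workers never contend for the same output file, so no coordination is needed beyond collecting the per-chunk summary dicts.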
+
+#### Example command:
+```bash
+# General setup
+CODE_DIR="/workspace/NeMo"
+export PYTHONPATH="${CODE_DIR}:${PYTHONPATH}"
+cd ${CODE_DIR}
+
+# Script parameters
+EXTENDED_MANIFEST_PATH="/path/to/hifitts2/nemo_manifest/extend_nemo_manifest_with_context_audio/hifitts2_train_withContextAudioMinDur3.0MinSSIM0.6.json" # From Step 2
+AUDIO_BASE_DIR="/path/to/audio/files"
+SAVE_DIR="/path/to/lhotse_shar_output"
+NUM_WORKERS=16 # Number of CPU cores
+SHARD_SIZE=4096
+
+echo "****** STEP 3: Creating Lhotse Shards from NeMo Manifest ******"
+python scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py \
+    --manifest-path ${EXTENDED_MANIFEST_PATH} \
+    --audio-base-dir ${AUDIO_BASE_DIR} \
+    --output-dir ${SAVE_DIR} \
+    --num-jobs ${NUM_WORKERS} \
+    --processing-chunk-size ${SHARD_SIZE} \
+    --audio-format 'flac' \
+    --log-level 'INFO'
+```
+
+#### Checkout the Outputs
+
+```bash
+$ tree -L 3 -P "*.000000.*" hifitts2/lhotse_shar/{cuts,target_audio,context_audio}
+hifitts2/lhotse_shar/cuts
+└── cuts.000000.jsonl.gz
+hifitts2/lhotse_shar/target_audio
+└── recording.000000.tar
+hifitts2/lhotse_shar/context_audio
+└── recording.000000.tar
+```
+
+### Step 4: Extend Lhotse Shards with Audio Codes
+
+This final step runs on **GPU nodes**. It processes the Lhotse shards created in Step 3 to extract and add audio codes.
+
+Improvements include:
+- Pre-allocation of Lhotse shards to each rank, with each rank processing and writing independently.
+- Pre-allocation of padded audio tensors, avoiding looped calls to `torch.nn.functional.pad`.
+- Avoids redundant zero-padding that was present in older recipes.
+
+#### Example command:
+```bash
+# General setup
+CODE_DIR="/workspace/NeMo"
+export PYTHONPATH="${CODE_DIR}:${PYTHONPATH}"
+cd ${CODE_DIR} + +# Codec parameters +CODEC_MODEL_NAME="21fpsCausalDecoder" +CODEC_MODEL_PATH="/path/to/your/codec_model.nemo" +CODEC_FRAME_RATE=21.5 + +# Trainer parameters +DEVICES=-1 # Number of GPUs, -1 for all +NUM_NODES=1 +BATCH_SIZE=64 +WANDB_ENTITY="xyz" +WANDB_PROJECT="xyz" +WANDB_NAME="xyz" + +# Path parameters +SHARD_DIR="/path/to/hifitts2/lhotse_shar" # From Step 3 + +echo "****** STEP 4: Extending Lhotse Shards with Audio Codes ******" +python scripts/magpietts/extend_lhotse_shards_with_audio_codes.py \ + --cuts-dir ${SHARD_DIR}/cuts \ + --target-audio-dir ${SHARD_DIR}/target_audio \ + --context-audio-dir ${SHARD_DIR}/context_audio \ + --output-dir ${SHARD_DIR} \ + --codec-model-name ${CODEC_MODEL_NAME} \ + --codec-model-path ${CODEC_MODEL_PATH} \ + --codec-frame-rate ${CODEC_FRAME_RATE} \ + --devices ${DEVICES} \ + --num-nodes ${NUM_NODES} \ + --batch-size ${BATCH_SIZE} \ + --wandb-entity ${WANDB_ENTITY} \ + --wandb-project ${WANDB_PROJECT} \ + --wandb-name ${WANDB_NAME} \ + --log-level 'INFO' \\ +``` + +### Checking the Outputs + +After running all four steps, you can check the files by looking at the output directory specified in Steps 3 and 4. + +```bash +# Examples of shard files: +$ tree -L 3 -P '*.000000.*' -I log hifitts2/lhotse_shar +hifitts2/lhotse_shar +├── codes_21fpsCausalDecoder # This is the subdir for audio codec codes. +│   ├── context_codes +│   │   └── codes.000000.tar # context audio codec codes. +│   └── target_codes +│   └── codes.000000.tar # target codec codes. +├── context_audio +│   └── recording.000000.tar # context audio waveforms. +├── cuts +│   └── cuts.000000.jsonl.gz # Lhotse manifest. +└── target_audio + └── recording.000000.tar # target audio waveforms. ``` When peek one of the item from `cuts.000000.jsonl.gz`, you should expect the structure as, ```python +In [4]: cutset = CutSet.from_shar(fields={"cuts": ["hifitts2/lhotse_shar/cuts/cuts.000000.jsonl.gz"], "target_audio": ["hifitts2/lhotse_shar/target_audio/recording.000000.tar"], "context_audio": ["h + ...: ifitts2/lhotse_shar/context_audio/recording.000000.tar"], "target_codes": ["hifitts2/lhotse_shar/codes_21fpsCausalDecoder/target_codes/codes.000000.tar"], "context_codes": ["hifitts2/lhotse_ + ...: shar/codes_21fpsCausalDecoder/context_codes/codes.000000.tar"]}) + +In [5]: cuts_list = [cut for cut in cutset] + +In [12]: from rich import print + +In [13]: print(cuts_list[0]) MonoCut( - id='cut-audio-11614_other-12352-prideofjennico_01_castle_0000', - start=0, - duration=6.16, + id='cut-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27-0.00-3.49', + start=0.0, + duration=3.49, channel=0, supervisions=[ SupervisionSegment( - id='sup-audio-11614_other-12352-prideofjennico_01_castle_0000', - recording_id='audio-11614_other-12352-prideofjennico_01_castle_0000', + id='sup-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27', + recording_id='rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27', start=0.0, - duration=6.16, + duration=3.49, channel=0, - text='late in the year seventeen seventy one as the wind rattles the casements with impotent clutch', + text='he was perhaps five years my senior', language='en', - speaker='| Language:en Dataset:HiFiTTS Speaker:11614 |', + speaker='| Language:en Dataset:hifitts2 Speaker:9216 |', gender=None, - custom={}, + custom={ + 'normalized_text': 'He was perhaps five years my senior.', + 
'text_source': 'mls', + 'wer': 0.0, + 'cer': 0.0, + 'speaker_count': 1, + 'bandwidth': 13953, + 'set': 'train', + 'context_speaker_similarity': 0.802838921546936, + 'context_audio_offset': 0.0, + 'context_audio_duration': 12.2, + 'context_audio_text': 'Vision of Helen," he called it, I believe.... The oblique stare of the hostile Trojans. Helen coifed with flame. Menelaus.', + 'context_audio_normalized_text': 'Vision of Helen," he called it, I believe.... The oblique stare of the hostile Trojans. Helen coifed with flame. Menelaus.', + 'context_recording_id': 'rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_30' + }, alignment=None ) ], features=None, recording=Recording( - id='audio-11614_other-12352-prideofjennico_01_castle_0000', - sources=[ - AudioSource( - type='memory', - channels=[0], - source='' - ) - ], - sampling_rate=44100, - num_samples=271656, - duration=6.16, + id='rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27', + sources=[AudioSource(type='file', channels=[0], source='/audio/9216/8716/9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27.flac')], + sampling_rate=22050, + num_samples=76955, + duration=3.4900226757369612, channel_ids=[0], transforms=None ), custom={ - 'codes_21fpsCausalDecoder': TemporalArray( - array=Array( - storage_type='memory_npy', - storage_path='', - storage_key='', - shape=[8, 133] - ), - temporal_dim=1, + 'target_audio': Recording( + id='cut-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27-0.00-3.49', + sources=[AudioSource(type='memory', channels=[0], source='')], + sampling_rate=22050, + num_samples=76955, + duration=3.49, + channel_ids=[0], + transforms=None + ), + 'context_audio': Recording( + id='context_cut-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_30-0.00-12.20', + sources=[AudioSource(type='memory', channels=[0], source='')], + sampling_rate=22050, + num_samples=269010, + duration=12.2, + channel_ids=[0], + transforms=None + ), + 'target_codes': TemporalArray( + array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 76]), + temporal_dim=-1, frame_shift=0.046511627906976744, start=0 ), - 'context_codes_21fpsCausalDecoder': TemporalArray( - array=Array( - storage_type='memory_npy', - storage_path='', - storage_key='', - shape=[8, 138] - ), - temporal_dim=1, + 'context_codes': TemporalArray( + array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 263]), + temporal_dim=-1, frame_shift=0.046511627906976744, start=0 ), - 'context_recording': Recording( - id='audio-11614_other-12220-barontrump_31_lockwood_0096', - sources=[ - AudioSource( - type='memory', - channels=[0], - source='' - ) - ], - sampling_rate=44100, - num_samples=282240, - duration=6.4, - channel_ids=[0], - transforms=None - ), - 'shard_origin': PosixPath('cuts.000000.jsonl.gz'), + 'shard_origin': 'hifitts2/lhotse_shar/cuts/cuts.000000.jsonl.gz', 'shar_epoch': 0 } ) ``` -## Appending New Codec Codes to Existing Lhotse Manifest -TBD. In genenral, the solution is to load existing cuts of shards, attach the new codec codes to the -MonoCut's `custom` field, and write cuts and new codec codes into shard files. This should uses the -same index of shards. +## Extending the Existing Lhotse Shar with New Audio Codec Codes +Given existing Lhotse Shar, i.e. 
cuts/target_audio/context_audio, you could just run the Python script +`scripts/magpietts/extend_lhotse_shards_with_audio_codes.py` by overriding with other audio codec models. The whole +process should be the same as Step 4 as mentioned above. ## (Internal Only) Sharing Slurm Job Sub Scripts to Create Lhotse Shar -All scripts are stored in -https://gitlab-master.nvidia.com/xueyang/nemo-tts-artifacts-registry/-/tree/main/data_prep_lhotse . +The internal scripts for submitting these steps as Slurm jobs can be found in the GitLab repository `nemo-tts-artifacts-registry` +repository, i.e. https://gitlab-master.nvidia.com/xueyang/nemo-tts-artifacts-registry/-/tree/model_release_2505/model_release_2505/data_prep. These scripts are tailored for specific cluster environments but can be adapted for other systems. ```shell -$ tree . -. -├── extract_audioCodec_21fpsCausalDecoder_eos.sub -├── hifitts2_extract_audioCodec_21fpsCausalDecoder_eos.sub -├── README_lhotse.md -├── reserve_interactive_node.sh -└── submit_jobs_for_all_datasets.sh - -$ bash submit_jobs_for_all_datasets.sh +$ tree -L 1 gitlab/nemo-tts-artifacts-registry/model_release_2505/data_prep/ +gitlab/nemo-tts-artifacts-registry/model_release_2505/data_prep/ +├── 1_submit_jobs_prep_input_manifest_iad.sh +├── 2_submit_jobs_extend_nemo_manifest_with_context_audio_iad.sh +├── 3_submit_jobs_create_lhotse_shards_from_nemo_manifest_iad.sh +├── 4_submit_jobs_extend_lhotse_shards_with_audio_codes_iad.sh +├── hifitts +├── hifitts2 +├── jhsdGtc20Amp20Keynote +├── libritts +├── librittsDevClean +├── nvyt2505 +├── README.md +├── rivaEmmaMeganSeanTom +└── rivaLindyRodney ``` diff --git a/scripts/magpietts/convert_nemo_to_lhotse_shar.py b/scripts/magpietts/convert_nemo_to_lhotse_shar.py deleted file mode 100644 index ec3e4b98ea83..000000000000 --- a/scripts/magpietts/convert_nemo_to_lhotse_shar.py +++ /dev/null @@ -1,464 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Example entry in `/home/xueyang/workspace/pretrain/data_prep/hifitts2/manifests/train_manifest_withContextAudioMinDur3.json` -{ - "audio_filepath": "100/2315/100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies_0.flac", - "duration": 6.2, - "speaker": "| Language:en Dataset:HiFiTTS2 Speaker:100 |", - "text": "THE oceans are big and broad. I believe two thirds of the earth's surface is covered with water.", - "normalized_text": "THE oceans are big and broad. 
I believe two thirds of the earth's surface is covered with water.", - "text_source": "book", - "bandwidth": 13092, - "snr1": 41.27, - "snr2": 41.05, - "snr3": 32.58, - "snr4": 22.28, - "is_segmented": true, - "wer": 0, - "cer": 0, - "ins": 0, - "del": 0, - "sub": 0, - "speaker_count": 1, - "chapter_id": "01", - "context_speaker_similarity": 0.9059218168258667, - "context_audio_filepath": "100/2315/100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies_1.flac", - "context_audio_duration": 6.08 -} - -Goal: avoid to save inodes quota by tarring individual files for audio codecs, audio waveforms and/or speaker embeddings. - We can decide if remove audio waveform files later. - -Examples of shard files: -$ tree data-shar-train/ - - cuts.000000.jsonl.gz: add all and exclude unnecessary fields. - - codes_21fpsCausalDecoder.000000.tar - - recording.000000.tar: not used during training, but worth to tar them so save inodes quota and for future applications. - - context_codes_21fpsCausalDecoder.000000.tar - - context_recording.000000.tar - - context_spk_embed.000000.tar (optional): speaker embedding is not used during training/validation. - - spk_embed.000000.tar (optional): speaker embedding is not used during training/validation. -""" - -import argparse -import os -import re -from functools import partial -from pathlib import Path - -import lightning.pytorch as pl -import torch -from lhotse import MonoCut, Recording, SupervisionSegment -from lhotse.shar.writers.shar import AudioTarWriter, SharWriter -from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import BasePredictionWriter -from lightning.pytorch.strategies import DDPStrategy -from torch.utils.data import DataLoader, Dataset - -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest -from nemo.collections.tts.models import AudioCodecModel - -MAYBE_EXTRA_METADATA_IN_MANIFEST = ["normalized_text", "speaker_count", "cer", "wer"] - - -def check_speaker_format(item: str): - # enforce the format as example like "| Language:en Dataset:HiFiTTS Speaker:9136_other |". - pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" - return bool(re.match(pattern, item)) - - -def get_recording_id(audio_base_dir: str, path: Path): - # the recording id is defined as the concatenation of relative audio filepath with a hyphen delimiter. - return path.relative_to(audio_base_dir).with_suffix("").as_posix().replace("/", "-") - - -class SharPredictionWriter(BasePredictionWriter): - def __init__( - self, - output_dir: str, - codec_model_name: str, - codec_frame_rate: float, - audio_base_dir: str, - fields: dict, - shard_size: int = 1000, - ): - super().__init__(write_interval="batch") - self.output_dir = output_dir - self.codec_model_name = codec_model_name - self.codec_frame_rate = codec_frame_rate - self.fields = fields - self.shard_size = shard_size - self.batch_counter = 0 - self.shar_writer = None - self.context_recording_writer = None - self.is_initialized = False - self.recording_id_fn = partial(get_recording_id, audio_base_dir) - - # Add a buffer with the shard size to accumulate cuts before writing to disk. 
- self.cuts_buffer = list() - self.buffer_size = shard_size - - def setup(self, trainer, pl_module, stage=None): - if not self.is_initialized: - # Only initialize the SharWriter and AudioTarWriter on rank 0 - if trainer.global_rank == 0: - os.makedirs(self.output_dir, exist_ok=True) - - # Initialize SharWriter - self.shar_writer = SharWriter( - output_dir=self.output_dir, fields=self.fields, shard_size=self.shard_size - ) - self.shar_writer.__enter__() - - # Initialize AudioTarWriter to store context recording as a workaround. - # TODO @xueyang: Without this, the process would be blocked because, - # When target duration is specified in MonoCut, the error will happen iff context duration < target duration - # mostly because the cut tries to trim the context_recording to the same duration as target. No errors - # were observed when context duration > target duration. Ref is https://nvidia.slack.com/archives/D068LR4TWUW/p1741817511544239 - self.context_recording_writer = AudioTarWriter( - pattern=os.path.join(self.output_dir, "context_recording.%06d.tar"), - shard_size=self.shard_size, - format="flac", - ) - self.context_recording_writer.__enter__() - - self.is_initialized = True - - def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx): - # prepare cuts from each rank - pred_cuts = self.convert_prediction_to_cuts(prediction) - # Gather predictions from all ranks - gathered_objects = [None] * trainer.world_size - torch.distributed.all_gather_object(gathered_objects, pred_cuts) - - # Only rank 0 writes to disk - if trainer.global_rank == 0: - for _pred_cuts, _context_recordings in gathered_objects: - if _pred_cuts is None or _context_recordings is None: - raise RuntimeError("Received None from all_gather_object") - - # Buffer the cuts - self.cuts_buffer.extend(list(zip(_pred_cuts, _context_recordings))) - - # Write when buffer is full - if len(self.cuts_buffer) >= self.buffer_size: - self._write_buffer() - - def _write_buffer(self): - """Write accumulated cuts from buffer""" - for cut, recording in self.cuts_buffer: - self.shar_writer.write(cut) - self.context_recording_writer.write( - key=cut.id, - value=recording.load_audio(), - sampling_rate=recording.sampling_rate, - manifest=recording, - ) - self.batch_counter += 1 - - # Clear the buffer - self.cuts_buffer = list() - - def convert_prediction_to_cuts(self, prediction): - # Extra useful metadata may exist in some manifests, so better to keep them for future usage. - meta_fields = { - meta_field: prediction[meta_field] - for meta_field in MAYBE_EXTRA_METADATA_IN_MANIFEST - if meta_field in prediction - } - - # This should convert predictions to Cut objects for Lhotse - cuts = list() - context_recordings = list() - - # batch process recordings and codes here. 
- # target recording - target_recordings = [ - Recording.from_file( - path=audio_filepath, - recording_id=self.recording_id_fn, - ) - for audio_filepath in prediction["target_audio_filepath"] - ] - context_recordings = [ - Recording.from_file( - path=audio_filepath, - recording_id=self.recording_id_fn, - ) - for audio_filepath in prediction["context_audio_filepath"] - ] - - # Create supervisions in batch - supervisions = [ - SupervisionSegment( - id=f"sup-{rec.id}", - recording_id=rec.id, - start=0.0, - duration=rec.duration, - channel=0, - text=text, - speaker=spk, - language=lang, - custom={key: val[idx] for key, val in meta_fields.items()} if meta_fields else None, - ) - for idx, (rec, text, spk, lang) in enumerate( - zip(target_recordings, prediction["text"], prediction["speaker"], prediction["language"]) - ) - ] - - # Create cuts in batch - # TODO @xueyang: should file a bug report to `attach_tensor` function. When `temporal_dim=-1`, the tensor is not - # attached correctly. For example, I found that `cuts[0].codes_21fpsCausalDecoder.load()` and - # `cuts[0].load_custom("codes_21fpsCausalDecoder")` returns different arrays, with different shapes. But the former - # returned expected (8,5) shape, while the latter returned (5,5). I also find that, after write shar files, and - # when i load codes using `CutSet.from_shar()` and no matter which load functions I used, they are all shape of (5,5) - # instead of (8,5). In any case, using default `temporal_dim` and `frame_shift` addressed this issue. - cuts = [ - MonoCut( - id=f"cut-{rec.id}", - start=0.0, - duration=rec.duration, - recording=rec, - channel=0, - supervisions=[sup], - custom={"context_recording": context_rec}, - ).attach_tensor( - name=f"codes_{self.codec_model_name}", - data=target_code, - # temporal_dim=1, - # frame_shift=1 / self.codec_frame_rate - ).attach_tensor( - name=f"context_codes_{self.codec_model_name}", - data=context_code, - # temporal_dim=1, - # frame_shift=1 / self.codec_frame_rate - ) - for rec, sup, context_rec, target_code, context_code in zip( - target_recordings, - supervisions, - context_recordings, - prediction["target_codes"], - prediction["context_codes"], - ) - ] - - return cuts, context_recordings - - def teardown(self, trainer, pl_module, stage=None): - # Wait for rank 0 to finish writing - if trainer.world_size > 1: - torch.distributed.barrier() - - # Close the SharWriter and AudioTarWriter on rank 0 - if trainer.global_rank == 0: - # Write any remaining cuts in the buffer before closing - if self.cuts_buffer: - self._write_buffer() - - if self.context_recording_writer is not None: - self.context_recording_writer.close() - if self.shar_writer is not None: - self.shar_writer.close() - - -class AudioDataset(Dataset): - def __init__(self, manifest: str, audio_base_dir: str, sample_rate: int = 22050, pad_multiple: int = 1024): - self.audio_base_dir = audio_base_dir - self.sample_rate = sample_rate - self.pad_multiple = pad_multiple - self.items = read_manifest(manifest) - - def __len__(self): - return len(self.items) - - def get_wav_from_filepath(self, file_path: str): - features = AudioSegment.segment_from_file( - file_path, - target_sr=self.sample_rate, - n_segments=-1, - trim=False, - ) - audio = torch.tensor(features.samples) - audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) - audio_length = torch.tensor(audio.size(0)).long() - return audio, audio_length - - def __getitem__(self, idx): - item = self.items[idx] - if not 
check_speaker_format(item["speaker"]): - raise ValueError(f"Invalid speaker format at index {idx}: {item}") - target_audio_filepath = os.path.join(self.audio_base_dir, item["audio_filepath"]) - context_audio_filepath = os.path.join(self.audio_base_dir, item["context_audio_filepath"]) - target_audio, target_audio_length = self.get_wav_from_filepath(target_audio_filepath) - context_audio, context_audio_length = self.get_wav_from_filepath(context_audio_filepath) - output_dict = { - "target_audio_filepath": target_audio_filepath, - "target_audio": target_audio, - "target_audio_length": target_audio_length, - "target_audio_duration": item["duration"], - "context_audio_filepath": context_audio_filepath, - "context_audio": context_audio, - "context_audio_length": context_audio_length, - "context_audio_duration": item["context_audio_duration"], - "context_speaker_similarity": item["context_speaker_similarity"], - "speaker": item["speaker"], - "text": item["text"], - "language": item["speaker"].strip().split()[1].split(":")[-1], - } - # Extra useful metadata may exist in some manifests, so better to keep them for future usage. - return self._copy_maybe_extra_metadata(item, output_dict) - - def collate_fn(self, batch): - max_target_audio_length = max(item["target_audio_length"].item() for item in batch) - target_audios_padded = [ - torch.nn.functional.pad(item["target_audio"], (0, max_target_audio_length - item["target_audio"].size(0))) - for item in batch - ] - max_context_audio_length = max(item["context_audio_length"].item() for item in batch) - context_audios_padded = [ - torch.nn.functional.pad( - item["context_audio"], (0, max_context_audio_length - item["context_audio"].size(0)) - ) - for item in batch - ] - output_dict = { - # target audio - "target_audio_filepath": [item["target_audio_filepath"] for item in batch], - "target_audios": torch.stack(target_audios_padded), - "target_audio_lengths": torch.stack([item["target_audio_length"] for item in batch]), - "target_audio_durations": [item["target_audio_duration"] for item in batch], - # context audio - "context_audio_filepath": [item["context_audio_filepath"] for item in batch], - "context_audios": torch.stack(context_audios_padded), - "context_audio_lengths": torch.stack([item["context_audio_length"] for item in batch]), - "context_audio_durations": [item["context_audio_duration"] for item in batch], - "context_speaker_similarity": [item["context_speaker_similarity"] for item in batch], - # metadata - "speaker": [item["speaker"] for item in batch], - "text": [item["text"] for item in batch], - "language": [item["language"] for item in batch], - } - # Extra useful metadata may exist in some manifests, so better to keep them for future usage. - for meta_field in MAYBE_EXTRA_METADATA_IN_MANIFEST: - if meta_field not in batch[0]: - continue - output_dict[meta_field] = [item[meta_field] for item in batch] - return output_dict - - @staticmethod - def _copy_maybe_extra_metadata(input_dict: dict, output_dict: dict): - # Extra useful metadata may exist in some manifests, so better to keep them for future usage. 
- for meta_field in MAYBE_EXTRA_METADATA_IN_MANIFEST: - if meta_field in input_dict: - output_dict[meta_field] = input_dict[meta_field] - return output_dict - - -class CodecExtractor(pl.LightningModule): - def __init__(self, model_path: str): - super().__init__() - self.codec_model = AudioCodecModel.restore_from(restore_path=model_path, strict=False) - self.codec_model.eval() - - def forward(self, batch): - with torch.no_grad(): - target_codes, target_codes_lengths = self.codec_model.encode( - audio=batch["target_audios"], audio_len=batch["target_audio_lengths"] - ) - context_codes, context_codes_lengths = self.codec_model.encode( - audio=batch["context_audios"], audio_len=batch["context_audio_lengths"] - ) - return { - "target_codes": target_codes.cpu().type(torch.int16), - "target_codes_lengths": target_codes_lengths, - "context_codes": context_codes.cpu().type(torch.int16), - "context_codes_lengths": context_codes_lengths, - } - - def predict_step(self, batch, batch_idx): - codes_dict = self(batch) - target_codes = [ - codes[:, :codes_length] - for codes, codes_length in zip(codes_dict["target_codes"], codes_dict["target_codes_lengths"]) - ] - context_codes = [ - codes[:, :codes_length] - for codes, codes_length in zip(codes_dict["context_codes"], codes_dict["context_codes_lengths"]) - ] - batch.update( - { - "target_codes": target_codes, - "context_codes": context_codes, - } - ) - return batch - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--manifest", type=str) - parser.add_argument("--audio_base_dir", type=str) - parser.add_argument("--save_dir", type=str) - parser.add_argument("--codec_model_name", type=str, default="21fpsCausalDecoder") - parser.add_argument("--codec_model_path", type=str) - parser.add_argument("--codec_frame_rate", type=float, default=21.5) - parser.add_argument("--pad_multiple", type=int, default=1024) - parser.add_argument("--sample_rate", type=int, default=22050) - parser.add_argument("--devices", type=int, default=-1) - parser.add_argument("--num_nodes", type=int, default=1) - parser.add_argument("--batch_size", type=int, default=48) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--shard_size", type=int, default=4096) - args = parser.parse_args() - - codec_extractor = CodecExtractor(args.codec_model_path) - - dataset = AudioDataset( - manifest=args.manifest, - audio_base_dir=args.audio_base_dir, - sample_rate=args.sample_rate, - pad_multiple=args.pad_multiple, - ) - dataloader = DataLoader( - dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, collate_fn=dataset.collate_fn - ) - - # Note that context_recording would be stored using AudioTarWriter. - pred_writer = SharPredictionWriter( - output_dir=args.save_dir, - codec_model_name=args.codec_model_name, - audio_base_dir=args.audio_base_dir, - codec_frame_rate=args.codec_frame_rate, - fields={ - "recording": "flac", - f"codes_{args.codec_model_name}": "numpy", - f"context_codes_{args.codec_model_name}": "numpy", - }, - shard_size=args.shard_size, - ) - - trainer = Trainer( - devices=args.devices, - accelerator="gpu", - strategy=DDPStrategy(find_unused_parameters=False), - num_nodes=args.num_nodes, - logger=False, - ) - # add writer callback to all gather batched predictions and write into shards. 
- trainer.callbacks.append(pred_writer) - - trainer.predict(codec_extractor, dataloaders=dataloader, return_predictions=False) diff --git a/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py b/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py new file mode 100644 index 000000000000..b200168b3a67 --- /dev/null +++ b/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py @@ -0,0 +1,504 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script requires the following updates to lhotse: add `shard_offset` in lhotse's writers. +$ pip install git+https://github.com/lhotse-speech/lhotse.git@883c24b5f6cdc4bbc73e89186e99f7907262b59c + +Example of manifest: + { + "audio_filepath": "train-clean-360/4098/11547/4098_11547_000032_000000.wav", + "text": "\"Isn't it?\" queried Theo.", + "speaker": "| Language:en Dataset:LibriTTS Speaker:4098 |", + "chapter_id": "11547", + "utter_id": "000032_000000", + "duration": 1.9700416666666667, + "normalized_text": "\"Isn't it?\" queried Theo.", + "context_speaker_similarity": 0.7800518870353699, + "context_audio_filepath": "train-clean-360/4098/11547/4098_11547_000031_000001.wav", + "context_audio_duration": 9.45 + } + +Example usage: + python scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py \ + --manifest-path ${MANIFEST} \ + --audio-base-dir ${AUDIO_BASE_DIR} \ + --output-dir ${OUTPUT_DIR} \ + --num-jobs ${NUM_JOBS} \ + --processing-chunk-size ${CHUNK_SIZE} \ + --audio-format ${AUDIO_FORMAT} \ + --log-level ${LOG_LEVEL} \ + 2>&1 | tee ./log/create_lhotse_shar_from_nemo_manifest.stdout + +Expected output: + $ tree ${OUTPUT_DIR} + ${OUTPUT_DIR}/ + cuts/ + cuts.000000.jsonl.gz + cuts.000001.jsonl.gz + ... + target_audio/ + recording.000000.tar + recording.000001.tar + ... + context_audio/ + recording.000000.tar + recording.000001.tar + ... +""" + +import argparse +import itertools +import logging +import math +import os +import re +from concurrent.futures import ProcessPoolExecutor, as_completed +from functools import partial +from pathlib import Path +from typing import Any, Dict, Tuple + +from lhotse import AudioSource, MonoCut, Recording, SupervisionSegment, compute_num_samples, fastcopy +from lhotse.serialization import load_jsonl +from lhotse.shar.writers import AudioTarWriter, JsonlShardWriter +from tqdm import tqdm + +NEMO_KEYS_NO_NEED_TO_LOG_IN_CUSTOM_FIELDS_FOR_SUPERVISION = [ + "audio_filepath", + "context_audio_filepath", + "text", + "offset", + "duration", + "speaker", +] + + +def to_shar_placeholder(recording: Recording, cut: MonoCut) -> Recording: + """this function is borrowed from lhotse.shar.writers.to_shar_placeholder. The only change for Recording instance is to update the id with cut.id.""" + return fastcopy( + recording, + id=cut.id, + # Creates a single AudioSource out of multiple ones. 
+ sources=[AudioSource(type="shar", channels=recording.channel_ids, source="")], + # Removes the transform metadata because they were already executed. + transforms=None, + duration=cut.duration, + num_samples=compute_num_samples(cut.duration, recording.sampling_rate), + ) + + +def check_speaker_format(item: str): + """Enforce speaker format like '| Language:en Dataset:HiFiTTS Speaker:9136_other |'""" + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + if not isinstance(item, str): + return False + return bool(re.match(pattern, item)) + + +def get_recording_id(relative_path: str) -> str: + """Generate a recording ID from the relative audio path.""" + return "rec-" + relative_path.rsplit(".", 1)[0].replace("/", "-") + + +def process_manifest_entry(entry: Dict[str, Any], audio_base_dir: Path) -> Tuple[MonoCut, MonoCut] | None: + """ + Processes a single entry from the NeMo manifest to create Lhotse objects. + + Returns: + tuple: (target_cut, context_cut) or None if an error occurs. + """ + try: + # Required fields + target_audio_path_relative = entry.get("audio_filepath") + context_audio_path_relative = entry.get("context_audio_filepath") + target_audio_duration = entry.get("duration") + context_audio_duration = entry.get("context_audio_duration") + text = entry.get("text") + # observed cases when text is empty while normalized_text is not. + if not text or not text.strip(): + text = entry.get("normalized_text") + speaker = entry.get("speaker") + + # Check required fields + if not all( + [ + target_audio_path_relative, + context_audio_path_relative, + target_audio_duration, + context_audio_duration, + text, + speaker, + ] + ): + logging.warning(f"Skipping entry due to missing fields: {entry}") + return None + + # Check speaker format + if not check_speaker_format(speaker): + logging.warning(f"Skipping entry due to incorrect speaker format: {entry}") + return None + + target_audio_filepath = audio_base_dir / target_audio_path_relative + context_audio_filepath = audio_base_dir / context_audio_path_relative + + if not target_audio_filepath.is_file(): + logging.warning( + f"Skipping entry due to missing target audio file: {target_audio_filepath} from entry: {entry}" + ) + return None + if not context_audio_filepath.is_file(): + logging.warning( + f"Skipping entry due to missing context audio file: {context_audio_filepath} from entry: {entry}" + ) + return None + + # Create IDs + target_recording_id = get_recording_id(target_audio_path_relative) + context_recording_id = get_recording_id(context_audio_path_relative) + + # Create Recordings + # TODO: if input is FLAC, then we should set AudioSegment.from_file(int_values=True). Does this applies to lhotse? + target_recording = Recording.from_file(target_audio_filepath, recording_id=target_recording_id) + context_recording = Recording.from_file(context_audio_filepath, recording_id=context_recording_id) + + # Custom fields exist in manifests, so better to keep them for future usage. + custom_fields = { + key: val + for key, val in entry.items() + if key not in NEMO_KEYS_NO_NEED_TO_LOG_IN_CUSTOM_FIELDS_FOR_SUPERVISION + } + custom_fields["context_recording_id"] = context_recording_id + + # Extract language from speaker string + lang_match = re.search(r"Language:(\w+)", speaker) + language = lang_match.group(1) if lang_match else None + + # offset in seconds + target_offset_in_seconds = entry.get("offset", 0.0) + context_offset_in_seconds = entry.get("context_audio_offset", 0.0) + + # Create Supervision for target cut. 
We constrain one supervision per cut for now. + supervision = SupervisionSegment( + id=f"sup-{target_recording_id}", + recording_id=target_recording_id, + start=target_offset_in_seconds, + duration=target_audio_duration, # duration from manifest + channel=0, # only support mono audio for now + text=text, + language=language, + speaker=speaker, + custom=custom_fields, + ) + + # Create target cut + target_cut_id = f"cut-{target_recording_id}-{target_offset_in_seconds:.2f}-{target_audio_duration:.2f}" + target_cut = MonoCut( + id=target_cut_id, + start=target_offset_in_seconds, + duration=target_audio_duration, + channel=0, # only support mono audio for now + recording=target_recording, + supervisions=[supervision], + ) + if not math.isclose(target_cut.duration, target_audio_duration, abs_tol=0.1): + logging.warning( + f"Manifest duration ({target_audio_duration}) differs significantly from cut duration ({target_cut.duration}) for {target_recording_id}. Using cut duration." + ) + target_cut.supervisions[0].duration = target_cut.duration + + # Create context cut. This cut is only used to load segmented audio and would not be stored in the final manifest. + context_cut_id = ( + f"context_cut-{context_recording_id}-{context_offset_in_seconds:.2f}-{context_audio_duration:.2f}" + ) + if context_cut_id.split("-", 1)[1] == target_cut_id.split("-", 1)[1]: + logging.warning(f"Context cut has the same recording segment as target cut. Skipping entry: {entry}") + return None + + context_cut = MonoCut( + id=context_cut_id, + start=context_offset_in_seconds, + duration=context_audio_duration, + channel=0, # only support mono audio for now + recording=context_recording, + ) + return target_cut, context_cut + + except Exception as e: + logging.error(f"Skipping entry due to error during metadata processing: {entry}: {e}", exc_info=True) + return None + + +def chunked_iterator(iterable, chunk_size): + """Yield successive chunks from iterable.""" + _it = iter(iterable) + while _chunk := tuple(itertools.islice(_it, chunk_size)): + yield _chunk + + +def process_and_write_chunk( + manifest_chunk_with_idx: Tuple[int, Tuple[Dict[str, Any], ...]], + audio_base_dir: Path, + output_dir: Path, + audio_format: str, +) -> Dict[str, int]: + """ + Processes a chunk of manifest entries, loads audio, and writes corresponding + single shard files for cuts, target audio, and context audio. + Designed to be run in a parallel worker process. + Loads and writes audio iteratively to save memory. + + Returns a dict containing processing stats like 'processed', 'initial_errors', 'audio_load_errors'. + """ + chunk_idx, manifest_chunk = manifest_chunk_with_idx + worker_pid = os.getpid() + logging.debug(f"[Worker {worker_pid}, Chunk {chunk_idx}] Starting processing {len(manifest_chunk)} entries.") + + # --- 1. Process manifest entries to get Cut objects --- + chunk_metadata = [] + initial_errors = 0 + for entry in manifest_chunk: + result = process_manifest_entry(entry, audio_base_dir=audio_base_dir) + if result is not None: + chunk_metadata.append(result) + else: + initial_errors += 1 + + if not chunk_metadata: + logging.warning( + f"[Worker {worker_pid}, Chunk {chunk_idx}] No valid entries after initial processing. Skipping." + ) + return {"processed": 0, "initial_errors": initial_errors, "audio_load_errors": 0, "write_errors": 0} + + logging.debug( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Collected {len(chunk_metadata)} cut pairs after initial processing." + ) + + # --- 2. 
Initialize writers and perform iterative load-and-write --- + cuts_dir = output_dir / "cuts" + target_recordings_dir = output_dir / "target_audio" + context_recordings_dir = output_dir / "context_audio" + + cuts_pattern = str(cuts_dir / "cuts.%06d.jsonl.gz") + target_rec_pattern = str(target_recordings_dir / "recording.%06d.tar") + context_rec_pattern = str(context_recordings_dir / "recording.%06d.tar") + + chunk_processed_count = 0 + chunk_audio_load_errors = 0 # Errors during audio loading phase for this chunk + chunk_write_errors = 0 # Errors during write phase for this chunk + + logging.debug( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Initializing writers with offset {chunk_idx} and processing {len(chunk_metadata)} pairs iteratively..." + ) + try: + # Specify shard_size with len(chunk_metadata) and shard_offset with chunk_idx, ensuring each chunk is written to a separate shard file. + shard_size_for_worker = len(chunk_metadata) + with ( + JsonlShardWriter( + pattern=cuts_pattern, shard_size=shard_size_for_worker, shard_offset=chunk_idx + ) as cut_writer, + AudioTarWriter( + pattern=target_rec_pattern, + shard_size=shard_size_for_worker, + format=audio_format, + shard_offset=chunk_idx, + ) as target_rec_writer, + AudioTarWriter( + pattern=context_rec_pattern, + shard_size=shard_size_for_worker, + format=audio_format, + shard_offset=chunk_idx, + ) as context_rec_writer, + ): + # Iterate directly over chunk_metadata + for target_cut, context_cut in chunk_metadata: + # 1. load target/context audio given the audio offset + try: + target_audio = target_cut.load_audio() + context_audio = context_cut.load_audio() + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Error loading target/context audio for cut {target_cut}: {e}", + exc_info=True, + ) + chunk_audio_load_errors += 1 + continue + + # 2. Write target audio and context audio + try: + target_rec_writer.write( + key=target_cut.id, + value=target_audio, + sampling_rate=target_cut.sampling_rate, + manifest=to_shar_placeholder( + target_cut.recording, target_cut + ), # update manifest.id with target_cut.id that has the audio offset and duration + ) + context_rec_writer.write( + key=target_cut.id, # use target cut id as key for context audio to ensure reference + value=context_audio, + sampling_rate=context_cut.sampling_rate, + manifest=to_shar_placeholder( + context_cut.recording, context_cut + ), # update manifest.id with context_cut.id that has the audio offset and duration + ) + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Error writing target/context audio for target cut {target_cut}: {e}", + exc_info=True, + ) + chunk_write_errors += 1 + continue + + # 3. write cut metadata + try: + cut_writer.write(target_cut) + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Error writing cut metadata for cut {target_cut}: {e}", + exc_info=True, + ) + chunk_write_errors += 1 + continue + + chunk_processed_count += 1 + + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] CRITICAL error during writer initialization: {e}", exc_info=True + ) + chunk_write_errors = len(chunk_metadata) + chunk_processed_count = 0 + + # This part is only reached if the main try block completes without critical errors + logging.debug( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Finished chunk. 
Processed: {chunk_processed_count}, Audio Load Errors: {chunk_audio_load_errors}, Write Errors: {chunk_write_errors}" + ) + + return { + "processed": chunk_processed_count, + "initial_errors": initial_errors, # Errors from initial metadata processing + "audio_load_errors": chunk_audio_load_errors, + "write_errors": chunk_write_errors, + } + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Convert NeMo manifest to sharded Lhotse JSONL/TARs using parallel workers per chunk.", + ) + parser.add_argument("--manifest-path", required=True, type=Path, help="Path to the input NeMo JSON manifest file.") + parser.add_argument( + "--audio-base-dir", required=True, type=Path, help="Base directory where audio files are located." + ) + parser.add_argument("--output-dir", required=True, type=Path, help="Base directory to save the sharded outputs.") + parser.add_argument( + "--num-jobs", + type=int, + default=max(1, os.cpu_count() // 2), + help="Number of parallel worker processes (each processing a whole chunk/shard).", + ) + parser.add_argument( + "--processing-chunk-size", + type=int, + default=4000, + help="Number of manifest entries per chunk (effectively the items per output shard file).", + ) + parser.add_argument( + "--audio-format", type=str, default="flac", help="Audio format for TAR writers (e.g., flac, wav, opus)." + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level for the main process and workers.", + ) + + args = parser.parse_args() + + # Configure logging based on argument + log_level = getattr(logging, args.log_level.upper(), logging.INFO) + log_format = '%(asctime)s - PID:%(process)d - %(levelname)s - %(message)s' + logging.basicConfig(level=log_level, format=log_format) + + # Ensure output directories exist + cuts_dir = args.output_dir / "cuts" + target_recordings_dir = args.output_dir / "target_audio" + context_recordings_dir = args.output_dir / "context_audio" + cuts_dir.mkdir(parents=True, exist_ok=True) + target_recordings_dir.mkdir(parents=True, exist_ok=True) + context_recordings_dir.mkdir(parents=True, exist_ok=True) + + logging.info(f"Reading NeMo manifest lazily from: {args.manifest_path}") + manifest_iterable = load_jsonl(args.manifest_path) + + logging.info( + f"Processing manifest in chunks of {args.processing_chunk_size} using {args.num_jobs} parallel workers..." + ) + + total_processed_count = 0 + total_initial_errors = 0 + total_audio_load_errors = 0 + total_write_errors = 0 + num_chunks = 0 + + worker_func = partial( + process_and_write_chunk, + audio_base_dir=args.audio_base_dir, + output_dir=args.output_dir, + audio_format=args.audio_format, + ) + + with ProcessPoolExecutor(max_workers=args.num_jobs) as executor: + # Enumerate chunks to pass index to worker. Each index is the same as the shard_offset. 
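+        # For illustration only (a sketch, not executed here): because each worker passes
+        # shard_offset=chunk_idx and shard_size=len(chunk_metadata) to its writers, chunk 0
+        # lands in cuts.000000.jsonl.gz / recording.000000.tar, chunk 1 in
+        # cuts.000001.jsonl.gz / recording.000001.tar, and so on, i.e. roughly:
+        #   shard_path = str(cuts_dir / ("cuts.%06d.jsonl.gz" % chunk_idx))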
+ chunk_iterator = enumerate(chunked_iterator(manifest_iterable, args.processing_chunk_size)) + futures = { + executor.submit(worker_func, chunk_with_idx): chunk_with_idx[0] for chunk_with_idx in chunk_iterator + } + num_chunks = len(futures) + + logging.info(f"Submitted {num_chunks} chunks to workers.") + + for future in tqdm(as_completed(futures), total=num_chunks, desc="Processing Chunks"): + chunk_idx = futures[future] + try: + result = future.result() + total_processed_count += result["processed"] + total_initial_errors += result["initial_errors"] + total_audio_load_errors += result["audio_load_errors"] + total_write_errors += result["write_errors"] + logging.debug(f"Chunk {chunk_idx} finished with result: {result}") + except Exception as e: + logging.error(f"Chunk {chunk_idx} failed with exception: {e}", exc_info=True) + # Increment error count based on chunk size. Difficult to know precisely. Assume all failed. + total_initial_errors += args.processing_chunk_size + + logging.info("=" * 30 + " Processing Summary " + "=" * 30) + logging.info(f"Total chunks processed: {num_chunks}") + logging.info(f"Successfully processed and wrote data for approximately {total_processed_count} entries.") + total_errors = total_initial_errors + total_audio_load_errors + total_write_errors + if total_errors > 0: + logging.warning(f"Encountered errors/skips in {total_errors} potential entries:") + logging.warning(f" - Initial processing errors/skips: {total_initial_errors}") + logging.warning(f" - Audio loading errors/skips (affecting writes): {total_audio_load_errors}") + logging.warning(f" - Writing errors: {total_write_errors}") + logging.warning("Check logs above (use DEBUG level for more details) for specific entry issues.") + else: + logging.info("No significant errors reported.") + logging.info("Manifest creation finished.") + + +if __name__ == "__main__": + main() diff --git a/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py b/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py new file mode 100644 index 000000000000..56038108ad70 --- /dev/null +++ b/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py @@ -0,0 +1,912 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script extends the Lhotse shards with audio codec codes. + +Example of input shards: + $ tree ${CUTS_DIR} + ${CUTS_DIR}/ + cuts.000000.jsonl.gz + cuts.000001.jsonl.gz + ... + + $ tree ${TARGET_AUDIO_DIR} + ${TARGET_AUDIO_DIR}/ + recording.000000.tar + recording.000001.tar + ... + + $ tree ${CONTEXT_AUDIO_DIR} + ${CONTEXT_AUDIO_DIR}/ + recording.000000.tar + recording.000001.tar + ... 
+ +Example usage: + export WANDB_API_KEY=${WANDB} + python -u ${CODE_DIR}/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py \ + --cuts-dir ${CUTS_DIR} \ + --target-audio-dir ${TARGET_AUDIO_DIR} \ + --context-audio-dir ${CONTEXT_AUDIO_DIR} \ + --output-dir ${RESULTS} \ + --codec-model-name ${CODEC_MODEL_NAME} \ + --codec-model-path ${CODEC_MODEL_PATH} \ + --codec-frame-rate ${CODEC_FRAME_RATE} \ + --devices ${DEVICES} \ + --num-nodes ${NUM_NODES} \ + --batch-size ${BATCH_SIZE} \ + --buffer-size ${BUFFER_SIZE} \ + --wandb-entity ${WANDB_ENTITY} \ + --wandb-project ${WANDB_PROJECT} \ + --wandb-name ${WANDB_NAME} \ + --log-level "DEBUG" \ + 2>&1 | tee ${LOG}/${WANDB_NAME}.stdout + +Expected output: + $ tree ${RESULTS} + ${RESULTS}/ + 21fpsCausalDecoder/ + target_codes/ + codes.000000.tar + codes.000001.tar + ... + context_codes/ + codes.000000.tar + codes.000001.tar + ... +""" + +import argparse +import glob +import logging +import os +import re +import threading +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import lightning.pytorch as pl +import torch +import wandb +from lhotse import CutSet +from lhotse.array import Array, TemporalArray +from lhotse.dataset import IterableDatasetWrapper, SimpleCutSampler +from lhotse.shar.writers.array import ArrayTarWriter +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import BasePredictionWriter +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.strategies import DDPStrategy +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from nemo.collections.tts.models import AudioCodecModel + + +def compute_effective_audio_length(original_audio_tensor: torch.Tensor, samples_per_frame: int) -> int: + """Computes the effective length of an audio tensor, padded to be a multiple of samples_per_frame.""" + original_len = original_audio_tensor.shape[0] + effective_len = original_len + if samples_per_frame > 0: + effective_len = ((original_len + samples_per_frame - 1) // samples_per_frame) * samples_per_frame + return effective_len + + +def collate_audio_vectors( + audio_list: List[torch.Tensor], audio_lens_list: List[int], padding_value: Union[float, int] +) -> torch.Tensor: + """ + Collate a list of audio vectors into a single tensor, handling padding for variable lengths. + Returns a padded tensor. + """ + assert all(len(t.shape) == 1 for t in audio_list), "Expected only 1-D input tensors." + assert len(audio_list) == len(audio_lens_list), "Expected the same number of audio vectors and lengths." + + # Create a padded tensor with the maximum audio length from audio_lens_list, where its max length could be longer than + # max length of `audio_list``. For example, `audio_lens_list` could be a multiple of the codec model samples per frame. + result = audio_list[0].new_ones(len(audio_lens_list), max(audio_lens_list)) * padding_value + for i, t in enumerate(audio_list): + result[i, : t.shape[0]] = t + return result + + +class AudioPairLhotseDataset(Dataset): + """ + A Lhotse Dataset that processes a batch of MonoCuts (received as a CutSet) + containing target and context audio. + Designed to be used with a Lhotse sampler yielding CutSet batches. + Handles loading audio and collating the batch within __getitem__. 
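+
+    Illustrative batch layout returned by __getitem__ (a sketch; B is the number of cuts in the
+    sampled CutSet and T the padded length in samples):
+
+        {
+            "target_audios":      FloatTensor [B, T],      # zero-padded target waveforms
+            "target_audio_lens":  IntTensor   [B],         # effective lengths, multiples of samples_per_frame
+            "context_audios":     FloatTensor [B, T_ctx],  # zero-padded context waveforms
+            "context_audio_lens": IntTensor   [B],
+            "target_cut_id":      list of str,             # one cut id per item
+            "shard_idx_origin":   list of int,             # parsed from the `shard_origin` custom field
+        }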
+ """ + + def __init__(self, target_sample_rate: int, codec_model_samples_per_frame: int): + self.target_sample_rate = target_sample_rate + self.codec_model_samples_per_frame = codec_model_samples_per_frame + + def __getitem__(self, cuts: CutSet) -> Optional[Dict[str, Any]]: + original_target_audios_list = [] + effective_target_lengths_list = [] + original_context_audios_list = [] + effective_context_lengths_list = [] + target_cut_ids_list = [] + shard_indices_list = [] + + for cut in cuts: + if not cut.has_custom("shard_origin"): + err_msg = f"Cut {cut} is missing required key 'shard_origin'." + logging.error(err_msg) + raise ValueError(err_msg) + if not cut.has_custom("context_recording"): + err_msg = f"Cut {cut} is missing required key 'context_recording'." + logging.error(err_msg) + raise ValueError(err_msg) + + # Parse shard index from the custom field, handling potential errors + origin_path = cut.custom["shard_origin"] + match = re.search(r"cuts\.(\d+)\.jsonl\.gz$", origin_path) + if match is None: + raise ValueError(f"Could not parse shard index from shard_origin: {origin_path}") + shard_idx_origin = int(match.group(1)) + + # audio shape: (num_channels (1), num_samples) -> (num_samples) + # resample to target sample rate + target_audio = torch.from_numpy(cut.recording.resample(self.target_sample_rate).load_audio().squeeze(0)) + context_audio = torch.from_numpy( + cut.context_recording.resample(self.target_sample_rate).load_audio().squeeze(0) + ) + original_target_audios_list.append(target_audio) + original_context_audios_list.append(context_audio) + + eff_target_len = compute_effective_audio_length(target_audio, self.codec_model_samples_per_frame) + effective_target_lengths_list.append(eff_target_len) + + eff_context_len = compute_effective_audio_length(context_audio, self.codec_model_samples_per_frame) + effective_context_lengths_list.append(eff_context_len) + + target_cut_ids_list.append(cut.id) + shard_indices_list.append(shard_idx_origin) + + # Ensure lists are not empty before calling collate_audio_vectors. + if not original_target_audios_list: + err_msg = "AudioPairLhotseDataset.__getitem__ processed an empty CutSet or failed to load any audio data, resulting in an empty audio list." + logging.error(err_msg) + raise ValueError(err_msg) + + target_audio_padded_batch = collate_audio_vectors( + original_target_audios_list, effective_target_lengths_list, padding_value=0.0 + ) + context_audio_padded_batch = collate_audio_vectors( + original_context_audios_list, effective_context_lengths_list, padding_value=0.0 + ) + + # TODO: is it really necessary to convert lengths to torch.int64? currently applying torch.int32. + target_audio_lens_collated = torch.IntTensor(effective_target_lengths_list) + context_audio_lens_collated = torch.IntTensor(effective_context_lengths_list) + + return { + "target_audios": target_audio_padded_batch, + "target_audio_lens": target_audio_lens_collated, + "context_audios": context_audio_padded_batch, + "context_audio_lens": context_audio_lens_collated, + "target_cut_id": target_cut_ids_list, + "shard_idx_origin": shard_indices_list, + } + + +class CodecExtractor(pl.LightningModule): + """ + LightningModule to extract codec codes. Manages DataLoader creation and + distribution via predict_dataloader hook. 
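+
+    Minimal usage sketch (paths and values are placeholders; see `main()` below for the real wiring):
+
+        extractor = CodecExtractor(
+            model_path="codec_model.nemo",
+            cuts_dir="shards/cuts",
+            target_audio_dir="shards/target_audio",
+            context_audio_dir="shards/context_audio",
+            batch_size=32,
+        )
+        trainer.predict(extractor, return_predictions=False)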
+ """ + + def __init__( + self, + model_path: str, + cuts_dir: str, + target_audio_dir: str, + context_audio_dir: str, + batch_size: int, + ): + super().__init__() + self.model_path = model_path + self.cuts_dir = Path(cuts_dir) + self.target_audio_dir = Path(target_audio_dir) + self.context_audio_dir = Path(context_audio_dir) + self.batch_size = batch_size + + logging.info(f"Initializing `AudioPairLhotseDataset` with model path: {self.model_path}") + # load the model. mapping to cpu is to avoid GPU mem spikes when initializing the model + self.codec_model = AudioCodecModel.restore_from(restore_path=self.model_path, map_location='cpu', strict=False) + self.codec_model.eval() + logging.info("Codec model loaded.") + + # Placeholder for the rank-specific list of dataloaders + self._rank_dataloaders: Optional[List[DataLoader]] = None + + def predict_dataloader(self) -> List[DataLoader]: + """ + Creates and returns the list of DataLoaders assigned to the current rank. + Caches the result to avoid redundant creation. + + This function is called by the Trainer to get the dataloaders for the current rank. This happens after + intializing `model.predict()` but before any actual prediction steps (ie. calls to `model.predict_step()`) are executed. + """ + # Return cached dataloaders if already created for this rank + if self._rank_dataloaders is not None: + return self._rank_dataloaders + + # Determine rank and world size + try: + # Prefer trainer attributes if available + current_global_rank = self.global_rank + world_size = self.trainer.world_size + except AttributeError: + # Fallback to torch.distributed if trainer attributes aren't set yet + current_global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + + logging.info(f"[Rank {current_global_rank}/{world_size}] Creating assigned subset of dataloaders...") + + # Find all shard files globally + cuts_shard_pattern = str(self.cuts_dir / "cuts.*.jsonl.gz") + all_cuts_shard_paths = sorted(glob.glob(cuts_shard_pattern)) + + if not all_cuts_shard_paths: + msg = f"[Rank {current_global_rank}/{world_size}] No input cut shards found matching pattern: {cuts_shard_pattern}. Cannot proceed." + logging.error(msg) + raise FileNotFoundError(msg) + + num_total_shards = len(all_cuts_shard_paths) + + # Verify shard indices are contiguous and start from 0 based on filenames (globally) + first_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[0]).group(1) + last_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[-1]).group(1) + first_idx = int(first_idx_str) + last_idx = int(last_idx_str) + expected_last_idx = num_total_shards - 1 + if first_idx != 0: + raise ValueError(f"Expected first shard index to be 0, but found {first_idx} in {all_cuts_shard_paths[0]}") + if last_idx != expected_last_idx: + raise ValueError( + f"Expected last shard index to be {expected_last_idx}, but found {last_idx} in {all_cuts_shard_paths[-1]}" + ) + logging.info( + f"[Rank {current_global_rank}/{world_size}] Verified {num_total_shards} total shard files globally, with indices from {first_idx} to {last_idx}." 
+ ) + + # Calculate the slice of original shard indices assigned to this rank + is_distributed = world_size > 1 + assigned_shard_indices_for_rank = [] + + if num_total_shards > 0: + if not is_distributed: + assigned_shard_indices_for_rank = list(range(num_total_shards)) + logging.info( + f"[Rank {current_global_rank}/{world_size}] Non-distributed mode. Will process all {num_total_shards} shards." + ) + else: + num_per_rank_base = num_total_shards // world_size + num_with_extra = num_total_shards % world_size + + if current_global_rank < num_with_extra: + start_shard_offset = current_global_rank * (num_per_rank_base + 1) + num_shards_for_rank = num_per_rank_base + 1 + else: + # Offset by the shards handled by ranks with an extra one + start_shard_offset = num_with_extra + current_global_rank * num_per_rank_base + num_shards_for_rank = num_per_rank_base + + end_shard_offset = start_shard_offset + num_shards_for_rank + assigned_shard_indices_for_rank = list(range(start_shard_offset, end_shard_offset)) + + logging.info( + f"[Rank {current_global_rank}/{world_size}] Assigned original shard indices " + f"{start_shard_offset} through {end_shard_offset -1} " + f"({len(assigned_shard_indices_for_rank)} shards)" + ) + + if not assigned_shard_indices_for_rank: + logging.info( + f"[Rank {current_global_rank}/{world_size}] No shards assigned to this rank. Returning empty dataloader list. This usually happens when the number of shards is less than the number of ranks." + ) + self._rank_dataloaders = [] + return [] + + # Create DataLoaders only for the shards assigned to this rank + dataloaders_for_rank = [] + for original_shard_idx in tqdm( + assigned_shard_indices_for_rank, + total=len(assigned_shard_indices_for_rank), + desc=f">>> [Rank {current_global_rank}/{world_size}] Creating DataLoaders for its assigned shards", + ): + logging.debug(f"[Rank {current_global_rank}] Processing original shard {original_shard_idx}...") + fields = { + "cuts": [str(self.cuts_dir / f"cuts.{original_shard_idx:06d}.jsonl.gz")], + "recording": [str(self.target_audio_dir / f"recording.{original_shard_idx:06d}.tar")], + "context_recording": [str(self.context_audio_dir / f"recording.{original_shard_idx:06d}.tar")], + } + # Verify if all files exist + if not all(Path(shard_filepaths[0]).is_file() for shard_filepaths in fields.values()): + err_msg = f"[Rank {current_global_rank}/{world_size}] Missing one or more files for shard {original_shard_idx}. Files: {fields}" + logging.error(err_msg) + raise FileNotFoundError(err_msg) + + try: + logging.debug( + f"[Rank {current_global_rank}] Loading CutSet for original shard {original_shard_idx}..." + ) + shard_cutset = CutSet.from_shar(fields=fields) + logging.debug(f"[Rank {current_global_rank}] Loaded CutSet for original shard {original_shard_idx}.") + except Exception as e: + logging.critical( + f"[Rank {current_global_rank}/{world_size}] CRITICAL ERROR: Failed to load CutSet from shar for original shard index {original_shard_idx}. \ + Files attempted: {fields}. 
\ + Error: {e}", + exc_info=True, + ) + raise + + logging.debug(f"[Rank {current_global_rank}] Creating Sampler for original shard {original_shard_idx}...") + # Explicitly set rank=0, world_size=1 to ensure sampler iterates the whole shard_cutset + sampler = SimpleCutSampler( + shard_cutset, max_cuts=self.batch_size, shuffle=False, drop_last=False, rank=0, world_size=1 + ) + logging.debug(f"[Rank {current_global_rank}] Creating Dataset for original shard {original_shard_idx}...") + shard_dataset = AudioPairLhotseDataset( + target_sample_rate=self.codec_model.sample_rate, + codec_model_samples_per_frame=self.codec_model.samples_per_frame, + ) + logging.debug(f"[Rank {current_global_rank}] Wrapping Dataset for original shard {original_shard_idx}...") + iterable_dataset = IterableDatasetWrapper( + dataset=shard_dataset, + sampler=sampler, + ) + logging.debug( + f"[Rank {current_global_rank}] Creating DataLoader for original shard {original_shard_idx}..." + ) + dl = DataLoader( + dataset=iterable_dataset, + batch_size=None, + num_workers=1, # Keep num_workers=1 for `IterableDatasetWrapper + SimpleCutSampler` to avoid duplicate batches. + pin_memory=True, + ) + logging.debug( + f"[Rank {current_global_rank}] Appending DataLoader for original shard {original_shard_idx}..." + ) + dataloaders_for_rank.append(dl) + logging.debug(f"[Rank {current_global_rank}] Finished processing original shard {original_shard_idx}.") + + logging.info( + f"[Rank {current_global_rank}/{world_size}] Created {len(dataloaders_for_rank)} DataLoaders for this rank." + ) + # Cache the created dataloaders for this rank + self._rank_dataloaders = dataloaders_for_rank + return self._rank_dataloaders + + def forward( + self, + target_audios: torch.Tensor, + target_audio_lens: torch.Tensor, + context_audios: torch.Tensor, + context_audio_lens: torch.Tensor, + ) -> Optional[Dict[str, torch.Tensor]]: + try: + target_audios = target_audios.to(self.device) + target_audio_lens = target_audio_lens.to(self.device) + context_audios = context_audios.to(self.device) + context_audio_lens = context_audio_lens.to(self.device) + # NOTE: we avoided directly calling `self.codec_model.encode()` because it pads audios again. 
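+            # The dataset has already padded every waveform up to a multiple of
+            # samples_per_frame (see compute_effective_audio_length), so the explicit
+            # audio_encoder + quantize path below receives frame-aligned inputs.
+            # Illustrative arithmetic with hypothetical numbers:
+            #   samples_per_frame = 2048, original_len = 22000
+            #   effective_len = ((22000 + 2048 - 1) // 2048) * 2048 = 22528  -> 11 codec frames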
+ with torch.inference_mode(): + target_audios_encoded, target_audios_encoded_len = self.codec_model.audio_encoder( + audio=target_audios, audio_len=target_audio_lens + ) + target_tokens = self.codec_model.quantize( + encoded=target_audios_encoded, encoded_len=target_audios_encoded_len + ) + context_audios_encoded, context_audios_encoded_len = self.codec_model.audio_encoder( + audio=context_audios, audio_len=context_audio_lens + ) + context_tokens = self.codec_model.quantize( + encoded=context_audios_encoded, encoded_len=context_audios_encoded_len + ) + return { + "target_codes": target_tokens.to(dtype=torch.int16, device="cpu"), + "target_codes_lengths": target_audios_encoded_len.to(device="cpu"), + "context_codes": context_tokens.to(dtype=torch.int16, device="cpu"), + "context_codes_lengths": context_audios_encoded_len.to(device="cpu"), + } + except Exception as e: + logging.error( + f"[Rank {self.global_rank}/{self.world_size}] Error during batched codec encoding: {e}", exc_info=True + ) + raise e + + def predict_step( + self, batch: Dict[str, Any], batch_idx: int, dataloader_idx: int = 0 + ) -> Optional[List[Dict[str, Any]]]: + codes_dict = self( + target_audios=batch["target_audios"], + target_audio_lens=batch["target_audio_lens"], + context_audios=batch["context_audios"], + context_audio_lens=batch["context_audio_lens"], + ) + + target_codes_batch = codes_dict["target_codes"] + target_codes_lens = codes_dict["target_codes_lengths"] + context_codes_batch = codes_dict["context_codes"] + context_codes_lens = codes_dict["context_codes_lengths"] + + target_cut_ids = batch["target_cut_id"] + shard_indices_in_batch = batch["shard_idx_origin"] + + # The shard_indices list should ideally contain the *same* original index + # for all items in a batch, because each DataLoader loads from only one shard. + results = [] + batch_size = batch["target_audios"].shape[0] + original_shard_idx = shard_indices_in_batch[0] + if not all(idx == original_shard_idx for idx in shard_indices_in_batch): + raise ValueError( + f"Inconsistent shard indices within batch! Batch Index: {batch_idx}, Dataloader Index: {dataloader_idx}. Indices: {shard_indices_in_batch}." + ) + + if len(target_cut_ids) != batch_size or target_codes_batch.shape[0] != batch_size: + raise ValueError( + f"Batch size mismatch after inference! Input IDs: {len(target_cut_ids)}, " + f"Input Audio Batch: {batch_size}, Output Codes Batch: {target_codes_batch.shape[0]}. " + f"Batch Index: {batch_idx}, Dataloader Index: {dataloader_idx}" + ) + + for target_cut_id, target_codes, context_codes, target_codes_len, context_codes_len in zip( + target_cut_ids, target_codes_batch, context_codes_batch, target_codes_lens, context_codes_lens + ): + results.append( + { + "target_cut_id": target_cut_id, + "shard_idx": original_shard_idx, + "target_codes": target_codes[:, :target_codes_len], + "context_codes": context_codes[:, :context_codes_len], + } + ) + + return results + + +class CodecPredictionWriter(BasePredictionWriter): + """ + Writes codec predictions (target and context codes) to ArrayTarWriter shards asynchronously. + Uses a ThreadPoolExecutor with a single worker to serialize writes and closing operations per shard, + allowing potential overlap between prediction computation and I/O while closing writers early. 
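+
+    Expected output layout (one tar per input cuts shard, with shard indices mirroring the inputs):
+
+        {output_dir}/{codec_model_name}/target_codes/codes.000000.tar
+        {output_dir}/{codec_model_name}/context_codes/codes.000000.tar
+        ...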
+ """ + + def __init__( + self, + output_dir: str, + codec_model_name: str, + codec_frame_rate: float, + ): + super().__init__(write_interval="batch") + self.output_dir_base = Path(output_dir) + self.codec_model_name = codec_model_name + self.codec_frame_rate = codec_frame_rate + self.rank: int = -1 + self.world_size: int = -1 + self.target_writers: Dict[int, ArrayTarWriter] = {} + self.context_writers: Dict[int, ArrayTarWriter] = {} + self.target_codes_dir: Optional[Path] = None + self.context_codes_dir: Optional[Path] = None + + # Attributes for asynchronous writing and closing + self.writer_lock: Optional[threading.Lock] = None + self.bg_worker_thread: Optional[ThreadPoolExecutor] = None + self.futures_per_shard: Optional[Dict[int, List[Future]]] = None + self.closer_futures: Optional[List[Future]] = None # Futures for the _wait_and_close_worker tasks + self.last_processed_shard_idx: int = -1 + + def setup(self, trainer: Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: + self.rank = trainer.global_rank + self.world_size = trainer.world_size + logging.info( + f"[Rank {self.rank}/{self.world_size}] Setting up CodecPredictionWriter for async writing with early close." + ) + + # Initialize async components + self.writer_lock = threading.Lock() + # Single worker ensures sequential execution of writes AND closes + self.bg_worker_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f'CodecWriterRank{self.rank}') + self.futures_per_shard = defaultdict(list) + self.closer_futures = [] + self.last_processed_shard_idx = -1 + + # Create directories + self.target_codes_dir = self.output_dir_base / self.codec_model_name / "target_codes" + self.context_codes_dir = self.output_dir_base / self.codec_model_name / "context_codes" + if self.rank == 0: + self.target_codes_dir.mkdir(parents=True, exist_ok=True) + self.context_codes_dir.mkdir(parents=True, exist_ok=True) + if trainer.world_size > 1: + torch.distributed.barrier() + logging.info(f"[Rank {self.rank}/{self.world_size}] Setup complete. Writers will be created on demand.") + + def _get_or_create_writer( + self, writer_dict: Dict[int, ArrayTarWriter], shard_idx: int, base_dir: Path + ) -> ArrayTarWriter: + # Lock needed as this might be called from main thread while closer task modifies dicts + with self.writer_lock: + if shard_idx not in writer_dict: + output_filename = str(base_dir / f"codes.{shard_idx:06d}.tar") + logging.debug( + f"[Rank {self.rank}/{self.world_size}] Creating writer for shard {shard_idx} (Thread-safe check): {output_filename}" + ) + try: + writer = ArrayTarWriter(pattern=output_filename, shard_size=None, compression="numpy") + writer.__enter__() + writer_dict[shard_idx] = writer + logging.info(f"[Rank {self.rank}/{self.world_size}] Created writer for shard {shard_idx}") + except Exception as e: + msg = f"[Rank {self.rank}/{self.world_size}] Failed to create writer for shard {shard_idx} (file: {output_filename}): {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + # Return writer even if it might be closed soon by a background task + # The background task handles the actual closing. 
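+            # Note: this is safe because the background executor has max_workers=1, so the
+            # _wait_and_close_worker task for a shard runs only after every write task submitted
+            # before it; a writer handed out here is never closed ahead of its pending writes.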
+ return writer_dict[shard_idx] + + def _write_worker( + self, + target_cut_id: str, + shard_idx: int, + target_codes: torch.Tensor, + context_codes: torch.Tensor, + target_writer: ArrayTarWriter, + context_writer: ArrayTarWriter, + ): + """Worker function executed by the background thread to write a single item.""" + # Assuming target_writer and context_writer are valid when this task starts + try: + target_codes_array_manifest = TemporalArray( + array=Array(storage_type="shar", storage_path="", storage_key="", shape=list(target_codes.shape)), + temporal_dim=-1, + frame_shift=1 / self.codec_frame_rate, + start=0, + ) + context_codes_array_manifest = TemporalArray( + array=Array(storage_type="shar", storage_path="", storage_key="", shape=list(context_codes.shape)), + temporal_dim=-1, + frame_shift=1 / self.codec_frame_rate, + start=0, + ) + target_writer.write(key=target_cut_id, value=target_codes.numpy(), manifest=target_codes_array_manifest) + context_writer.write(key=target_cut_id, value=context_codes.numpy(), manifest=context_codes_array_manifest) + logging.debug(f"[Worker Rank {self.rank}] Wrote item {target_cut_id} for shard {shard_idx}") + except Exception as e: + msg = f"[Worker Rank {self.rank}] CRITICAL I/O ERROR writing item {target_cut_id} for shard {shard_idx}: {e}. Writer might be closed prematurely?" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + def _wait_and_close_worker(self, shard_idx_to_close: int): + """Waits for all write tasks of a shard, then closes and removes its writers.""" + logging.info(f"[Worker Rank {self.rank}] Starting closure process for shard {shard_idx_to_close}") + # 1. Retrieve and remove the list of write futures for this shard + # Do this early to prevent new futures being added for this closing shard? + # No, write_on_batch_end logic prevents submission for old shards. + write_futures = self.futures_per_shard.pop(shard_idx_to_close, []) + + # 2. Wait for all write operations for this shard to complete + logging.info( + f"[Worker Rank {self.rank}] Waiting for {len(write_futures)} write tasks for shard {shard_idx_to_close}..." + ) + processed_write_futures = 0 + if write_futures: + for f in write_futures: + try: + f.result() # Wait for completion + processed_write_futures += 1 + except Exception as e: + # Write worker already logged this, but log context here + logging.error( + f"[Worker Rank {self.rank}] Exception during write future.result() for shard {shard_idx_to_close}: {e}", + exc_info=False, + ) + logging.info( + f"[Worker Rank {self.rank}] Completed {processed_write_futures}/{len(write_futures)} write tasks for shard {shard_idx_to_close}." + ) + else: + logging.warning( + f"[Worker Rank {self.rank}] No write futures found to wait for shard {shard_idx_to_close} during close." + ) + + # 3. Safely remove and close the writers + writers_closed_count = 0 + with self.writer_lock: # Protect access to the writer dictionaries + target_writer = self.target_writers.pop(shard_idx_to_close, None) + context_writer = self.context_writers.pop(shard_idx_to_close, None) + + if target_writer: + try: + target_writer.close() + logging.info(f"[Worker Rank {self.rank}] Closed target writer for shard {shard_idx_to_close}.") + writers_closed_count += 1 + except Exception as e: + logging.error( + f"[Worker Rank {self.rank}] Error closing target writer for shard {shard_idx_to_close}: {e}", + exc_info=True, + ) + else: + logging.warning( + f"[Worker Rank {self.rank}] Target writer for shard {shard_idx_to_close} not found during close." 
+ ) + + if context_writer: + try: + context_writer.close() + logging.info(f"[Worker Rank {self.rank}] Closed context writer for shard {shard_idx_to_close}.") + writers_closed_count += 1 + except Exception as e: + logging.error( + f"[Worker Rank {self.rank}] Error closing context writer for shard {shard_idx_to_close}: {e}", + exc_info=True, + ) + else: + logging.warning( + f"[Worker Rank {self.rank}] Context writer for shard {shard_idx_to_close} not found during close." + ) + + logging.info( + f"[Worker Rank {self.rank}] Finished closure process for shard {shard_idx_to_close}. Closed {writers_closed_count} writers." + ) + + def write_on_batch_end( + self, + trainer: Trainer, + pl_module: pl.LightningModule, + predictions: Optional[List[Dict[str, Any]]], + batch_indices: Optional[List[int]], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + + if not predictions: + err_msg = f"[Rank {self.rank}/{self.world_size}] Received empty predictions list for batch_idx {batch_idx}, dataloader_idx {dataloader_idx}. Skipping." + logging.error(err_msg) + raise ValueError(err_msg) + + current_shard_idx = predictions[0]["shard_idx"] + if not all(p["shard_idx"] == current_shard_idx for p in predictions): + raise ValueError( + f"[Rank {self.rank}] Inconsistent shard indices within batch! Batch Index: {batch_idx}, Dataloader Index: {dataloader_idx}." + ) + + # Check for shard change and submit closer task for the previous shard + if current_shard_idx != self.last_processed_shard_idx and self.last_processed_shard_idx != -1: + logging.info( + f"[Rank {self.rank}] Shard index changed from {self.last_processed_shard_idx} to {current_shard_idx}. " + f"Submitting closure task for shard {self.last_processed_shard_idx}." + ) + try: + closer_future = self.bg_worker_thread.submit( + self._wait_and_close_worker, self.last_processed_shard_idx + ) + self.closer_futures.append(closer_future) + except Exception as e: + msg = f"[Rank {self.rank}] Failed to submit closer task for shard {self.last_processed_shard_idx}: {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + self.last_processed_shard_idx = current_shard_idx + + # Submit write tasks for each item in the current batch + for prediction in predictions: + try: + target_cut_id = prediction["target_cut_id"] + shard_idx = prediction["shard_idx"] + target_codes = prediction["target_codes"] + context_codes = prediction["context_codes"] + + # This needs the lock because the closer task might be removing entries concurrently + target_writer = self._get_or_create_writer(self.target_writers, shard_idx, self.target_codes_dir) + context_writer = self._get_or_create_writer(self.context_writers, shard_idx, self.context_codes_dir) + + # Submit the writing task + write_future = self.bg_worker_thread.submit( + self._write_worker, + target_cut_id, + shard_idx, + target_codes, + context_codes, + target_writer, + context_writer, + ) + self.futures_per_shard[shard_idx].append(write_future) + logging.debug(f"[Rank {self.rank}] Submitted write task for item {target_cut_id}, shard {shard_idx}") + + except Exception as e: + msg = f"[Rank {self.rank}] Error processing prediction item {prediction.get('target_cut_id', 'UNKNOWN')} from batch {batch_idx}: {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + def teardown(self, trainer: Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: + logging.info( + f"[Rank {self.rank}/{self.world_size}] Tearing down CodecPredictionWriter. 
Handling final shard and waiting for closers..." + ) + + # 1. Submit closer task for the very last processed shard (if any) + final_shard_processed = self.last_processed_shard_idx + if final_shard_processed != -1 and final_shard_processed in self.futures_per_shard: + logging.info( + f"[Rank {self.rank}] Submitting final closure task for last processed shard {final_shard_processed}." + ) + try: + closer_future = self.bg_worker_thread.submit(self._wait_and_close_worker, final_shard_processed) + self.closer_futures.append(closer_future) + except Exception as e: + msg = f"[Rank {self.rank}] Failed to submit final closer task for shard {final_shard_processed}: {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + # 2. Wait for all closer tasks to complete + num_closer_futures = len(self.closer_futures) + logging.info( + f"[Rank {self.rank}/{self.world_size}] Waiting for {num_closer_futures} background closer tasks to complete." + ) + processed_closer_futures = 0 + if self.closer_futures: + for future in tqdm( + self.closer_futures, + total=num_closer_futures, + desc=f"[Rank {self.rank}/{self.world_size}] Finalizing Shard Closures", + leave=False, + ): + try: + future.result() # Wait and check for exceptions from the closer worker + processed_closer_futures += 1 + except Exception as e: + msg = f"[Rank {self.rank}/{self.world_size}] Exception caught during closer future.result(): {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + logging.info( + f"[Rank {self.rank}/{self.world_size}] Completed {processed_closer_futures}/{num_closer_futures} closer tasks." + ) + else: + logging.info(f"[Rank {self.rank}/{self.world_size}] No closer tasks were submitted.") + + # 3. Shutdown the executor gracefully (all tasks should be done now) + if self.bg_worker_thread: + logging.info(f"[Rank {self.rank}/{self.world_size}] Shutting down background worker thread.") + self.bg_worker_thread.shutdown(wait=True) + self.bg_worker_thread = None + + # 4. Final sanity checks and cleanup + remaining_writers = len(self.target_writers) + len(self.context_writers) + if remaining_writers > 0: + msg = f"[Rank {self.rank}/{self.world_size}] {remaining_writers} writers remain after teardown! This should not happen. Keys: Target {list(self.target_writers.keys())}, Context {list(self.context_writers.keys())}" + logging.error(msg) + raise ValueError(msg) + + remaining_futures = sum(len(futs) for futs in self.futures_per_shard.values()) + if remaining_futures > 0: + msg = f"[Rank {self.rank}/{self.world_size}] {remaining_futures} write futures remain after teardown! This should not happen. Shards: {list(self.futures_per_shard.keys())}" + logging.error(msg) + raise ValueError(msg) + + self.target_writers.clear() + self.context_writers.clear() + self.futures_per_shard.clear() + self.closer_futures.clear() + + logging.info(f"[Rank {self.rank}/{self.world_size}] Teardown complete.") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuts-dir", type=str, required=True, help="Directory containing input cuts/cuts.*.jsonl.gz shards." + ) + parser.add_argument( + "--target-audio-dir", type=str, required=True, help="Directory containing target_audio/recording.*.tar shards." 
+ ) + parser.add_argument( + "--context-audio-dir", + type=str, + required=True, + help="Directory containing context_audio/recording.*.tar shards.", + ) + parser.add_argument("--output-dir", type=str, required=True, help="Base directory to save the output code shards.") + parser.add_argument( + "--codec-model-name", + type=str, + default="21fpsCausalDecoder", + help="Name for codec model (used in output path).", + ) + parser.add_argument( + "--codec-model-path", type=str, required=True, help="Path to the NeMo codec model (.nemo file)." + ) + parser.add_argument("--codec-frame-rate", type=float, default=21.5, help="Frame rate for codec model.") + parser.add_argument("--devices", type=int, default=-1, help="Number of GPUs per node (-1 for all).") + parser.add_argument("--num-nodes", type=int, default=1, help="Number of nodes for distributed processing.") + parser.add_argument("--batch-size", type=int, default=32, help="Batch size PER GPU for codec inference.") + parser.add_argument( + "--buffer-size", type=int, default=256, help="Number of items to buffer before writing to TAR files." + ) + parser.add_argument("--wandb-entity", type=str, default=None, help="Wandb entity.") + parser.add_argument("--wandb-project", type=str, default="lhotse_codes_extraction", help="Wandb project.") + parser.add_argument("--wandb-name", type=str, default=None, help="Wandb run name.") + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level.", + ) + args = parser.parse_args() + + log_level_val = getattr(logging, args.log_level.upper(), logging.INFO) + log_format = '%(asctime)s - PID:%(process)d - %(levelname)s - %(message)s' + logging.basicConfig(level=log_level_val, format=log_format) + + codec_extractor = CodecExtractor( + model_path=args.codec_model_path, + cuts_dir=args.cuts_dir, + target_audio_dir=args.target_audio_dir, + context_audio_dir=args.context_audio_dir, + batch_size=args.batch_size, + ) + + pred_writer = CodecPredictionWriter( + output_dir=args.output_dir, + codec_model_name=args.codec_model_name, + codec_frame_rate=args.codec_frame_rate, + ) + + wandb_logger = None + if args.wandb_entity and args.wandb_project: + wandb_logger = WandbLogger( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name or f"extract_codes_{args.codec_model_name}_{os.path.basename(args.cuts_dir)}", + log_model=False, + ) + logging.info(f"Wandb logging enabled to {args.wandb_entity}/{args.wandb_project}") + + strategy = DDPStrategy(find_unused_parameters=False) if torch.cuda.is_available() and args.devices != 1 else "auto" + trainer = Trainer( + devices=args.devices if torch.cuda.is_available() else 1, + num_nodes=args.num_nodes, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + strategy=strategy, + logger=wandb_logger, + callbacks=[pred_writer], + use_distributed_sampler=False, # we should disable replacing or wrapping Lhostse CutSampler with a `DistributedSamplerWrapper` since Lhotse's sampler already handles distributed sampling. 
+ ) + + logging.info(f"Starting prediction with {trainer.world_size} ranks.") + trainer.predict(codec_extractor, return_predictions=False) + logging.info("Prediction finished.") + + if trainer.is_global_zero and wandb_logger: + wandb.finish() + logging.info("Wandb run finished.") + + +if __name__ == "__main__": + import torch.multiprocessing + + try: + torch.multiprocessing.set_start_method('spawn') + except RuntimeError: + # This exception occurs if the start method has already been set. We can safely ignore it. + pass + main() diff --git a/scripts/magpietts/extend_nemo_manifest_with_context_audio.py b/scripts/magpietts/extend_nemo_manifest_with_context_audio.py new file mode 100644 index 000000000000..426fcac99d16 --- /dev/null +++ b/scripts/magpietts/extend_nemo_manifest_with_context_audio.py @@ -0,0 +1,873 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import os +import re +import time +from collections import defaultdict +from pathlib import Path + +import lightning.pytorch as pl +import torch +import wandb +from lhotse.dataset.collation import collate_vectors +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.strategies import DDPStrategy +from torch.utils.data import DataLoader, Dataset + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment + +logger = logging.getLogger(__name__) + +""" +Usage: +python scripts/magpietts/extend_manifest_with_context_audio.py + --manifest /path/to/input.json + --audio-base-dir /path/to/audio + --output-dir /path/to/output_sharded_manifests + --batch-size 16 + --devices 2 + --num-nodes 1 + --flush-threshold-items 20000 + --num-workers 4 + --context-min-duration 3.0 + --context-min-ssim 0.6 + +This script distributes speakers across DDP ranks. Each rank processes its assigned speakers +and writes a partial manifest. Rank 0 then merges these into a final manifest. + +Input manifest example entry: +{ + "audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", + "text": "the face.", + "speaker": "| Language:en Dataset:NVYT_2505 Speaker:_8Kirz57BTY_SPEAKER_01 |", + "offset": 2.8, + "duration": 0.48, + "bandwidth": 10125, + "stoi_squim": 0.98, + "sisdr_squim": 18.235, + "pesq_squim": 2.349, + "dataset_id": "369a9f1a-65eb-4c09-8de3-8babea29da4c", + "dataset_version": "2024_11_02_092919", + "dataset_name": "yt_mixed", + "normalized_text": "the face." 
+} + +Output manifest example entry: +{ + "audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", + "text": "the face.", + "speaker": "| Language:en Dataset:NVYT_2505 Speaker:_8Kirz57BTY_SPEAKER_01 |", + "offset": 2.8, + "duration": 0.48, + "bandwidth": 10125, + "stoi_squim": 0.98, + "sisdr_squim": 18.235, + "pesq_squim": 2.349, + "dataset_id": "369a9f1a-65eb-4c09-8de3-8babea29da4c", + "dataset_version": "2024_11_02_092919", + "dataset_name": "yt_mixed", + "normalized_text": "the face.", + "context_audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", + "context_audio_offset": 5.6, + "context_audio_duration": 6.0, + "context_audio_text": "would you mind..", + "context_audio_normalized_text": "would you mind..", + "context_audio_speaker_similarity": 0.85 +} +""" + + +def check_speaker_format(item: str): + """Enforce speaker format like '| Language:en Dataset:HiFiTTS Speaker:9136_other |'""" + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + if not isinstance(item, str): + return False + return bool(re.match(pattern, item)) + + +class SpeakerShardedAudioDataset(Dataset): + def __init__(self, assigned_records_list, base_audio_dir, sample_rate=16000): + self.sample_rate = sample_rate + self.base_audio_dir = base_audio_dir + self.processed_records = assigned_records_list + + def __len__(self): + return len(self.processed_records) + + def get_wav_from_filepath(self, audio_filepath_rel, offset_in_sec=0, duration_in_sec=None): + full_audio_filepath = os.path.join(self.base_audio_dir, audio_filepath_rel) + try: + features = AudioSegment.from_file( + audio_file=full_audio_filepath, + target_sr=self.sample_rate, + int_values=False, # TODO: if input is FLAC, then we should set this to True. + offset=offset_in_sec, + duration=duration_in_sec, + ) + except Exception as e: + logger.warning( + f"[Skipping Wav Load] Failed for `{full_audio_filepath}` (relative: `{audio_filepath_rel}`, offset={offset_in_sec}, duration={duration_in_sec}): {e}" + ) + return None, None + audio_samples = features.samples + return torch.tensor(audio_samples), torch.tensor(len(audio_samples)).long() + + def __getitem__(self, idx): + item_info = self.processed_records[idx] + + audio, audio_length = self.get_wav_from_filepath( + item_info["audio_filepath"], item_info["offset"], item_info["duration"] + ) + if audio is None or audio_length is None: + return None + + output_item = item_info.copy() + output_item.update( + { + "audio": audio, + "audio_length": audio_length, + } + ) + return output_item + + def collate_fn(self, batch): + valid_items = [item for item in batch if item is not None] + if not valid_items: + return { + "audios": torch.empty(0), + "audio_lengths": torch.empty(0), + "metadata_list": [], + "parsed_speaker_ids_list": [], + } + + audio_padded = collate_vectors([item["audio"] for item in valid_items], padding_value=0.0) + audio_lengths = torch.tensor([item["audio_length"] for item in valid_items]) + metadata_list = [ + {k: v for k, v in item.items() if k not in ['audio', 'audio_length', 'parsed_speaker_id']} + for item in valid_items + ] + parsed_speaker_ids_for_batch = [item['parsed_speaker_id'] for item in valid_items] + + return { + "audios": audio_padded, + "audio_lengths": audio_lengths, + "metadata_list": metadata_list, + "parsed_speaker_ids_list": parsed_speaker_ids_for_batch, + } + + +class EmbeddingSimilarityExtractorSharded(pl.LightningModule): + def __init__( + self, + output_dir: str, + output_file_prefix: str, + flush_threshold_items: int, + context_min_duration: float, + 
context_min_ssim: float, + speaker_expected_counts_map: dict, + initial_assigned_count: int, + ): + super().__init__() + self.sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + 'titanet_large', map_location=torch.device('cpu') + ) + self.sv_model.eval() + + self.output_dir = Path(output_dir) + self.output_file_prefix = output_file_prefix + self.flush_threshold_items = flush_threshold_items + self.context_min_duration = context_min_duration + self.context_min_ssim = context_min_ssim + self.speaker_expected_counts = speaker_expected_counts_map + self.initial_assigned_count = initial_assigned_count + + # Per-rank attributes + self.output_file_path = None + self.speaker_data_accumulator = defaultdict(list) + self.total_accumulated_items = 0 # total number of items accumulated across all speakers for this rank + self.processed_speakers_set = set() # set of speakers that have been processed and flushed + self.ready_to_flush_speaker_ids = set() # set of speakers that have accumulated enough items to be flushed + self.output_manifest_file = None + # total num of items discarded due to no suitable context for this rank + self.total_discarded_no_suitable_context_this_rank = 0 + self.total_items_written_this_rank = 0 # total items written to manifest by this rank + + def setup(self, stage: str): + if stage == "predict": + self.sv_model.to(self.device) + self.output_file_path = self.output_dir / f"{self.output_file_prefix}_rank{self.global_rank}.json" + self.output_dir.mkdir(parents=True, exist_ok=True) + self.output_manifest_file = open(self.output_file_path, "w", encoding="utf-8") + logger.info(f"Writing partial manifest to: `{self.output_file_path}`") + logger.debug(f"Expected speaker counts for model: {self.speaker_expected_counts}") + + def forward(self, batch): + with torch.no_grad(): + _, speaker_embeddings = self.sv_model.forward( + input_signal=batch['audios'], + input_signal_length=batch['audio_lengths'], + ) + return speaker_embeddings + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + if batch['audios'].nelement() == 0: + return [] + + speaker_embeddings_gpu = self(batch) + + output_items_for_batch_end = [] + for i, single_metadata_item in enumerate(batch["metadata_list"]): + embedding_cpu_fp32 = speaker_embeddings_gpu[i].cpu().type(torch.float32) + base_speaker_id_for_item = batch["parsed_speaker_ids_list"][i] + + processed_item = { + "speaker_id_for_grouping": base_speaker_id_for_item, + "embedding": embedding_cpu_fp32, + "metadata": single_metadata_item, + } + output_items_for_batch_end.append(processed_item) + + return output_items_for_batch_end + + def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): + for item in outputs: + base_speaker_id = item['speaker_id_for_grouping'] + + if base_speaker_id not in self.processed_speakers_set: + self.speaker_data_accumulator[base_speaker_id].append( + {'embedding': item['embedding'], 'metadata': item['metadata']} + ) + self.total_accumulated_items += 1 + + expected_count = self.speaker_expected_counts[base_speaker_id] + current_count = len(self.speaker_data_accumulator[base_speaker_id]) + + if current_count == expected_count: + self.ready_to_flush_speaker_ids.add(base_speaker_id) + logger.debug( + f"Speaker {base_speaker_id} is complete with {current_count} items. Added to `ready_to_flush_speaker_ids`." + ) + elif current_count > expected_count: + msg = f"Speaker {base_speaker_id} has {current_count} items, but expected {expected_count}. 
Possible data inconsistency or error in expected counts." + logger.error(msg) + raise ValueError(msg) + else: + msg = f"Received new item for already processed speaker '{base_speaker_id}'. This may indicate issues with data sharding, expected counts, or duplicate data." + logger.error(msg) + raise ValueError(msg) + + if self.total_accumulated_items >= self.flush_threshold_items and self.ready_to_flush_speaker_ids: + self._process_and_flush_speakers_local() + + def _process_and_flush_speakers_local(self): + speakers_to_process_now = list(self.ready_to_flush_speaker_ids) + self.ready_to_flush_speaker_ids.clear() + + if not speakers_to_process_now: + msg = "_process_and_flush_speakers_local called, but `speakers_to_process_now` is empty after list conversion. This is unexpected." + logger.error(msg) + raise ValueError(msg) + + logger.info( + f"Flushing {len(speakers_to_process_now)} completed speakers. " + f"Current total accumulated items: {self.total_accumulated_items}" + ) + + for speaker_id in speakers_to_process_now: + speaker_items = self.speaker_data_accumulator.pop(speaker_id) + self.total_accumulated_items -= len(speaker_items) + self.processed_speakers_set.add(speaker_id) + + # NOTE: Potential OOM (Out Of Memory) risk if a single speaker has an extremely large + # number of segments (e.g., tens of thousands). The N x N similarity matrix calculated below + # (where N = len(speaker_items)) can consume significant CPU RAM. + # For example, 50,000 segments for one speaker could lead to a float32 similarity matrix + # requiring approximately 10 GB of RAM. Consider this if processing datasets with + # speakers having a very high number of utterances. + embeddings = torch.stack([item['embedding'] for item in speaker_items]) + embeddings_norm = torch.nn.functional.normalize(embeddings, p=2, dim=1) + similarity_matrix = torch.matmul(embeddings_norm, embeddings_norm.transpose(0, 1)) + similarity_matrix.fill_diagonal_(-2.0) # cosine similarity range is [-1, 1] + + # Sort all similarities for each item to iterate through candidates + # best_similarities_tensor will contain sorted similarities for each row (original item) + # best_indices_tensor will contain original indices of these sorted items + sorted_similarities_tensor, sorted_indices_tensor = torch.sort(similarity_matrix, dim=1, descending=True) + + record_preparation_start_time = time.time() + num_records_written_for_speaker = 0 + # Initialize a counter for items discarded for this specific speaker + num_discarded_for_this_speaker_no_context = 0 + + for i, current_item_data in enumerate(speaker_items): + output_record = current_item_data['metadata'].copy() + write_this_record = False + + # Iterate through potential candidates, sorted by similarity + for candidate_rank in range(sorted_indices_tensor.size(1)): + candidate_ssim = sorted_similarities_tensor[i, candidate_rank].item() + original_candidate_idx = sorted_indices_tensor[i, candidate_rank].item() + + # Skip if candidate is the item itself (safeguard) + if original_candidate_idx == i: + continue + + # If SSIM is below threshold, stop searching for this item (since candidates are sorted) + if candidate_ssim < self.context_min_ssim: + break + + # Check duration if SSIM is acceptable + best_meta_dict = speaker_items[original_candidate_idx]['metadata'] + candidate_duration = best_meta_dict["duration"] + + if candidate_duration >= self.context_min_duration: + # Found a suitable candidate, update record and stop searching for this item + record_update_dict = { + 
"context_speaker_similarity": candidate_ssim, + "context_audio_filepath": best_meta_dict["audio_filepath"], + "context_audio_offset": best_meta_dict["offset"], + "context_audio_duration": candidate_duration, + "context_audio_text": best_meta_dict["text"], + } + normalized_text_candidate = best_meta_dict.get("normalized_text", None) + if normalized_text_candidate is not None: + record_update_dict["context_audio_normalized_text"] = normalized_text_candidate + + output_record.update(record_update_dict) + write_this_record = True + break + + if write_this_record: + self.output_manifest_file.write(json.dumps(output_record, ensure_ascii=False) + "\n") + num_records_written_for_speaker += 1 + else: + # This item will be discarded as no suitable context was found + num_discarded_for_this_speaker_no_context += 1 + + # Accumulate to rank-level total + self.total_discarded_no_suitable_context_this_rank += num_discarded_for_this_speaker_no_context + self.total_items_written_this_rank += num_records_written_for_speaker + + if len(speakers_to_process_now) > 0: + self.output_manifest_file.flush() # ensure all data currently held in the buffer is immediately written to disk. + logger.info(f"Flushing of completed speakers done. Local remaining items: {self.total_accumulated_items}") + + def on_predict_epoch_end(self): + logger.info( + f"Epoch end: Identifying remaining speakers to flush. " + f"Speakers in accumulator: {len(self.speaker_data_accumulator)}, Already processed: {len(self.processed_speakers_set)}" + ) + + for speaker_id, items in list(self.speaker_data_accumulator.items()): + if speaker_id not in self.processed_speakers_set: + expected_count = self.speaker_expected_counts[speaker_id] + actual_count = len(items) + if actual_count == expected_count: + logger.info( + f"Epoch end: Speaker {speaker_id} is complete ({actual_count}/{expected_count}). Adding to ready set." + ) + self.ready_to_flush_speaker_ids.add(speaker_id) + else: + msg = f"Epoch end: Speaker {speaker_id} is still in accumulator with {actual_count} items, but expected {expected_count}. This indicates an issue, e.g., not all data for this speaker was received or processed during the epoch." + logger.error(msg) + raise ValueError(msg) + + if self.ready_to_flush_speaker_ids: + logger.info( + f"Epoch end: Calling `_process_and_flush_speakers_local` for {len(self.ready_to_flush_speaker_ids)} ready speakers." 
+ ) + self._process_and_flush_speakers_local() + else: + logger.info(f"Epoch end: No remaining speakers identified as ready to flush.") + + if self.speaker_data_accumulator: # Should be empty if all went well + msg = f"Epoch end: {len(self.speaker_data_accumulator)} speakers still in accumulator post-final flush attempt: {list(self.speaker_data_accumulator.keys())}" + logger.error(msg) + raise ValueError(msg) + + logger.info( + f"Total items discarded on this rank due to no suitable context found (failed SSIM or duration): {self.total_discarded_no_suitable_context_this_rank}" + ) + logger.info(f"Total items written to manifest on this rank: {self.total_items_written_this_rank}") + + # Verification step + expected_total_processed = ( + self.total_items_written_this_rank + self.total_discarded_no_suitable_context_this_rank + ) + if self.initial_assigned_count == expected_total_processed: + logger.info( + f"Verification successful: Initial items ({self.initial_assigned_count}) == Written ({self.total_items_written_this_rank}) + Discarded ({self.total_discarded_no_suitable_context_this_rank})" + ) + else: + msg = f"VERIFICATION FAILED: Initial items ({self.initial_assigned_count}) != Written ({self.total_items_written_this_rank}) + Discarded ({self.total_discarded_no_suitable_context_this_rank}) --- Difference: {self.initial_assigned_count - expected_total_processed}" + logger.error(msg) + raise RuntimeError(msg) + + if self.output_manifest_file: + self.output_manifest_file.close() + self.output_manifest_file = None + logger.info(f"Local processing complete. Partial manifest closed.") + + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Wait for all ranks to finish writing files + + +def _parse_speaker_id_libritts(record): + """ + libritts format: audio_filepath = "{subset}/{speaker_id}/{chapter_id}/{speaker_id}_{chapter_id}_{utterance_id}_{segment_id}.wav" + e.g. "train-clean-100/89/218/89_218_000014_000003.wav" + re-organized speaker_id: "{subset}_{speaker_id}_{chapter_id}" + e.g. "train-clean-100_89_218" + """ + parts = record['audio_filepath'].lower().split('/') + return f"{parts[0]}_{parts[1]}_{parts[2]}" + + +def _parse_speaker_id_hifitts(record): + """ + hifitts format: audio_filepath = "{speaker_id}_{audio_quality}/{book_id}/{chapter_name}_{segment_id}.wav" + e.g. "11614_other/12352/prideofjennico_01_castle_0000.flac" + re-organized speaker_id: "{speaker_id}_{audio_quality}_{book_id}_{chapter_name}" + e.g. "11614_other_12352_prideofjennico_01_castle" + """ + parts = record['audio_filepath'].lower().split('/') + chapter_name = parts[-1].rsplit('_', 1)[0] + return f"{parts[0]}_{parts[1]}_{chapter_name}" + + +def _parse_speaker_id_hifitts2(record): + """ + hifitts2 format: audio_filepath = "{speaker_id}/{book_id}/{speaker_id}_{book_id}_{chapter_name}_{segment_id}.wav" + e.g. "100/2315/100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies_0.flac" + re-organized speaker_id: "{speaker_id}_{book_id}_{chapter_name}" + e.g. "100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies" + """ + parts = record['audio_filepath'].lower().split('/') + return parts[-1].rsplit('_', 1)[0] + + +def _parse_speaker_id_nvyt2505(record): + """ + nvyt2505 format: audio_filepath = "NVYT_40K_audios_wav/{utterance_id}.wav", which does not contain speaker_id. + e.g. "NVYT_40K_audios_wav/Thg50o7gmsk.wav" + But we can parse the speaker_id from: speaker = "| Language:en Dataset:NVYT_2505 Speaker:Thg50o7gmsk_SPEAKER_00 |". + re-organized speaker_id: "{parsed_speaker_id}" + e.g. 
"thg50o7gmsk_speaker_00" + """ + speaker_regex = re.compile(r'Speaker:([^ |]+)') + match = speaker_regex.search(record['speaker']) + if not match: + raise ValueError(f"Failed to parse speaker_id from record: {record}") + return match.group(1).lower() + + +def _parse_speaker_id_rivaLindyRodney(record): + """ + rivaLindyRodney format: audio_filepath = "{speaker}/44khz/{emotion}/{speaker}_{emotion}_{utterance_id}.wav" + e.g. "Lindy/44khz/WIZWIKI/LINDY_WIZWIKI_004161.wav" + re-organized speaker_id: "{speaker}_{emotion}" + e.g. "lindy_wizwiki" + """ + parts = record['audio_filepath'].lower().split('/') + return f"{parts[0]}_{parts[2]}" + + +def _parse_speaker_id_rivaEmmaMeganSeanTom(record): + """ + rivaEmmaMeganSeanTom format: audio_filepath = "{speaker}/22_kHz/{speaker}_{emotion}_{utterance_id}.wav" + e.g. "Emma/22_kHz/Emma_Sad_Intense_Correlated_00147.wav" + re-organized speaker_id: "{speaker}_{emotion}" + e.g. "emma_sad_intense_correlated" + """ + parts = record['audio_filepath'].lower().split('/') + return parts[2].rsplit('_', 1)[0] + + +def _parse_speaker_id_jhsdGtc20Amp20Keynote(record): + """ + jhsdGtc20Amp20Keynote format: audio_filepath = "{keynote_event}_KEYNOTE-VOOnly-44khz-16bit-mono_{utterance_id}.wav" + e.g. "AMP20_KEYNOTE-VOOnly-44khz-16bit-mono_12.wav" + re-organized speaker_id: "{keynote_event}" + e.g. "AMP20" + """ + return record['audio_filepath'].lower().rsplit('_', 2)[0] + + +def _get_parsed_speaker_id_for_dataset(dataset_name_arg, record): + """Routes to the appropriate speaker ID parsing function based on dataset_name.""" + if dataset_name_arg == "libritts": + return _parse_speaker_id_libritts(record) + elif dataset_name_arg == "librittsDevClean": + return _parse_speaker_id_libritts(record) + elif dataset_name_arg == "hifitts": + return _parse_speaker_id_hifitts(record) + elif dataset_name_arg == "hifitts2": + return _parse_speaker_id_hifitts2(record) + elif dataset_name_arg == "nvyt2505": + return _parse_speaker_id_nvyt2505(record) + elif dataset_name_arg == "rivaLindyRodney": + return _parse_speaker_id_rivaLindyRodney(record) + elif dataset_name_arg == "rivaEmmaMeganSeanTom": + return _parse_speaker_id_rivaEmmaMeganSeanTom(record) + elif dataset_name_arg == "jhsdGtc20Amp20Keynote": + return _parse_speaker_id_jhsdGtc20Amp20Keynote(record) + else: + logger.error( + f"Unsupported dataset_name '{dataset_name_arg}' provided. Please check the --dataset-name argument." + ) + raise ValueError(f"Unsupported dataset_name: {dataset_name_arg}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--manifest", type=str, required=True) + parser.add_argument("--audio-base-dir", type=str, required=True) + parser.add_argument("--output-dir", type=str, required=True, help="Directory to save rank-specific manifests.") + parser.add_argument( + "--dataset-name", + type=str, + required=True, + choices=[ + "libritts", + "librittsDevClean", + "hifitts", + "hifitts2", + "nvyt2505", + "rivaLindyRodney", + "rivaEmmaMeganSeanTom", + "jhsdGtc20Amp20Keynote", + ], + help="Name of the dataset being processed. This determines the speaker ID parsing logic.", + ) + parser.add_argument("--flush-threshold-items", type=int, default=20000) + parser.add_argument( + "--context-min-duration", type=float, default=3.0, help="Minimum duration for a context audio segment." + ) + parser.add_argument( + "--context-min-ssim", type=float, default=0.6, help="Minimum cosine similarity for a context audio segment." 
+ ) + parser.add_argument("--devices", type=int, default=-1) + parser.add_argument("--num-nodes", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--wandb-entity", type=str, default=None) + parser.add_argument("--wandb-project", type=str, default="speaker_similarity_sharded") + parser.add_argument("--wandb-name", type=str, default=None) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level.", + ) + args = parser.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level.upper()), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Initialize DDP early to get rank and world_size for sharding + # PyTorch Lightning Trainer will handle DDP initialization if not done explicitly, + # but we need rank/world_size for data sharding before Trainer setup. + ddp_env_vars_detected = "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ + + user_intended_distributed = False + if isinstance(args.devices, int) and args.devices not in [0, 1]: # 0 for CPU, 1 for single GPU. -1 means all GPUs. + user_intended_distributed = True + if args.num_nodes > 1: + user_intended_distributed = True + + if user_intended_distributed and not ddp_env_vars_detected: + logger.warning( + f"Warning: A distributed run seems intended (num_nodes={args.num_nodes}, devices='{args.devices}'), " + f"but standard DDP environment variables (e.g., `LOCAL_RANK`, `WORLD_SIZE`) were not detected pre-Trainer initialization. " + f"If launching on SLURM, ensure you are using `srun` or have correctly configured your sbatch script. " + f"For local multi-GPU, consider using `torchrun`. " + f"PyTorch Lightning will now attempt to initialize the distributed environment. " + f"If it defaults to a single process, data sharding will be ineffective (all data processed by one rank)." + ) + + strategy = ( + DDPStrategy(find_unused_parameters=False) + if (isinstance(args.devices, int) and args.devices != 1 and args.devices != 0) + else "auto" + ) + + trainer = Trainer( + devices=args.devices, + num_nodes=args.num_nodes, + accelerator="gpu", + strategy=strategy, + logger=None, + max_epochs=1, + use_distributed_sampler=False, + ) + + world_size = trainer.world_size + global_rank = trainer.global_rank + + log_format = f"%(asctime)s [RANK {global_rank}] [%(levelname)s] %(message)s" + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.basicConfig(level=getattr(logging, args.log_level.upper()), format=log_format, datefmt="%Y-%m-%d %H:%M:%S") + + if global_rank == 0: + logger.info("Reading and sharding manifest ...") + + temp_sv_model_for_config = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + 'titanet_large', map_location=torch.device('cpu') + ) + # Initialize sample_rate for all ranks; rank 0 will populate it. + # This variable will be broadcast if in distributed mode. 
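+ # Illustrative arithmetic (assuming titanet_large's default 16 kHz preprocessor with a 10 ms hop,
+ # i.e. featurizer.hop_length == 160): min_duration_in_sec_required = 160 * 2 / 16000 = 0.02 s,
+ # so a segment must cover at least two feature frames to yield a usable speaker embedding.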
+ sample_rate = temp_sv_model_for_config.preprocessor._sample_rate + min_duration_in_sec_required = temp_sv_model_for_config.preprocessor.featurizer.hop_length * 2 / sample_rate + del temp_sv_model_for_config + logger.info( + f"Calculated sample_rate: {sample_rate}, min_duration_in_sec_required: {min_duration_in_sec_required:.3f}s" + ) + + speaker_to_records = defaultdict(list) + num_processed_records = 0 + total_initial_records = 0 + + with open(args.manifest, "r", encoding="utf-8") as f: + for line in f: + total_initial_records += 1 + try: + rec = json.loads(line.strip()) + except json.JSONDecodeError: + logger.warning(f"Skipping malformed JSON line: `{line.strip()}`") + continue + + # 1. Apply duration filter + if rec.get("duration") is None or rec.get("duration") < min_duration_in_sec_required: + continue + + # 2. Apply speaker format check + if not check_speaker_format(rec["speaker"]): + msg = f"Invalid speaker format for record: {rec['speaker']}, File: {rec['audio_filepath']}(offset={rec['offset']}, duration={rec['duration']})." + logger.error(msg) + raise ValueError(msg) + + # 3. Parse speaker ID and add to map + rec['parsed_speaker_id'] = _get_parsed_speaker_id_for_dataset(args.dataset_name, rec) + + speaker_to_records[rec['parsed_speaker_id']].append(rec) + num_processed_records += 1 + + num_filtered_out_initial_pass = total_initial_records - num_processed_records + + logger.info( + f"Initial pass filtered out {num_filtered_out_initial_pass} records (e.g., duration). Processing {num_processed_records} records before speaker count filter." + ) + + # Filter speakers with less than 2 segments + speakers_before_count_filter = len(speaker_to_records) + + speakers_with_segment_counts = [ + {"count": len(rec_list), "records": rec_list} + for _, rec_list in speaker_to_records.items() + if len(rec_list) >= 2 + ] + del speaker_to_records + + speakers_after_count_filter = len(speakers_with_segment_counts) + records_after_count_filter = sum(item["count"] for item in speakers_with_segment_counts) + + num_speakers_filtered_by_count = speakers_before_count_filter - speakers_after_count_filter + num_records_filtered_by_speaker_count = num_processed_records - records_after_count_filter + + logger.info( + f"Filtered out {num_speakers_filtered_by_count} speakers (and {num_records_filtered_by_speaker_count} corresponding records) with < 2 segments. " + f"Now processing {records_after_count_filter} records from {speakers_after_count_filter} speakers for sharding." + ) + + # Greedy Bin-Packing for speaker distribution + # 1. Sort speakers by segment count in descending order + speakers_with_segment_counts.sort(key=lambda x: x["count"], reverse=True) + + # 2. Initialize rank loads and assignments + rank_loads = [0] * world_size + rank_assignments = [[] for _ in range(world_size)] + + # 3. Assign speakers to ranks using greedy approach + for speaker_info in speakers_with_segment_counts: + # Find the rank with the minimum current load + min_load_rank_idx = rank_loads.index(min(rank_loads)) + + # Assign all records of this speaker to that rank + rank_assignments[min_load_rank_idx].extend(speaker_info["records"]) + # Update the load of that rank + rank_loads[min_load_rank_idx] += speaker_info["count"] + + data_to_distribute = rank_assignments + logger.info( + f"Sharding complete. {sum(len(r) for r in data_to_distribute)} records distributed among {world_size} ranks." 
+ ) + for r, recs in enumerate(data_to_distribute): + logger.info(f"Plan for rank {r} = {len(recs)} records.") + + per_rank_speaker_counts = [] # [{"spk_0": 10, "spk_1": 5}, {"spk_2": 6, "spk_3": 8}, ...] + for rank_idx in range(world_size): + counts_for_rank = defaultdict(int) + for record in data_to_distribute[rank_idx]: + counts_for_rank[record['parsed_speaker_id']] += 1 + per_rank_speaker_counts.append(dict(counts_for_rank)) + + else: # Other ranks prepare to receive + data_to_distribute = [None] * world_size + per_rank_speaker_counts = [None] * world_size + sample_rate = None # Initialize for non-rank 0 before broadcast + + # Broadcast the list of lists of records. Each rank will then pick its part. + if world_size > 1 and not torch.distributed.is_initialized(): + logger.warning( + f"Distributed run (world_size={world_size}) detected, but `torch.distributed` not yet initialized. " + f"Attempting to trigger environment setup via `trainer.strategy.setup_environment()`." + ) + # The trainer's strategy is responsible for setting up the distributed environment. + # This typically happens implicitly during trainer.fit/predict/test/validate calls. + trainer.strategy.setup_environment() + if torch.distributed.is_initialized(): + logger.info( + f"`torch.distributed` successfully initialized after `trainer.strategy.setup_environment()`. Synchronizing ranks." + ) + torch.distributed.barrier() # Ensure all ranks have completed setup before proceeding. + else: + msg = f"[Rank {global_rank}] FATAL: Failed to initialize `torch.distributed` even after calling `trainer.strategy.setup_environment()` for world_size={world_size}. Cannot proceed with distributed data sharding." + logger.error(msg) + raise RuntimeError(msg) + elif world_size == 1 and torch.distributed.is_initialized(): + # This case should ideally not happen (DDP initialized for a single process run by Lightning). + logger.warning(f"Warning: `torch.distributed` is initialized, but world_size is 1. This is unusual.") + elif world_size > 1 and torch.distributed.is_initialized(): + logger.info(f"`torch.distributed` was already initialized. world_size={world_size}. Synchronizing ranks.") + torch.distributed.barrier() + + # Now, proceed with the data distribution logic, expecting `torch.distributed` to be initialized if world_size > 1. 
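+ # Sketch of the hand-off implemented below (values are illustrative only): on rank 0,
+ # data_to_distribute is the fully populated list of per-rank record lists, e.g. [[recA, recB], [recC, recD]]
+ # for world_size=2, while every other rank holds [None, None]; broadcast_object_list(..., src=0) replicates
+ # the lists, and each rank then keeps data_to_distribute[global_rank] plus its per_rank_speaker_counts entry
+ # and the broadcast sample_rate.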
+ my_speaker_expected_counts = {} + if torch.distributed.is_initialized(): + torch.distributed.broadcast_object_list(data_to_distribute, src=0) + assigned_records_for_this_rank = data_to_distribute[global_rank] + torch.distributed.broadcast_object_list(per_rank_speaker_counts, src=0) + my_speaker_expected_counts = per_rank_speaker_counts[global_rank] + + # Broadcast sample_rate + if global_rank == 0: + sample_rate_to_broadcast = [sample_rate] + else: + sample_rate_to_broadcast = [None] + torch.distributed.broadcast_object_list(sample_rate_to_broadcast, src=0) + sample_rate = sample_rate_to_broadcast[0] + logger.info(f"Received {len(assigned_records_for_this_rank)} records for processing.") + logger.debug(f"Expected speaker counts for this rank: {my_speaker_expected_counts}") + logger.info(f"Received sample_rate via broadcast: {sample_rate}") + elif world_size == 1: + # data_to_distribute is already prepared by rank 0 code block if world_size was 1 from start + assigned_records_for_this_rank = data_to_distribute[0] if data_to_distribute and data_to_distribute[0] else [] + my_speaker_expected_counts = ( + per_rank_speaker_counts[0] if per_rank_speaker_counts and per_rank_speaker_counts[0] else {} + ) + if not assigned_records_for_this_rank: + msg = f"[Rank {global_rank}] Error: No records were assigned for processing in single process mode. Issue in initial data prep." + logger.error(msg) + raise ValueError(msg) + logger.info(f"Single process, assigned {len(assigned_records_for_this_rank)} records.") + logger.debug(f"Expected speaker counts: {my_speaker_expected_counts}") + logger.info(f"Using sample_rate from rank 0 execution: {sample_rate}") + else: + msg = f"[Rank {global_rank}] Critical: DDP not initialized for sharding, and not a single process run. Cannot determine configuration." + logger.error(msg) + raise ValueError(msg) + + # Validate that sample_rate is now available on all ranks before use + if sample_rate is None: + msg = f"[Rank {global_rank}] Critical error: sample_rate was not correctly set or broadcasted. Value is None." + logger.error(msg) + raise RuntimeError(msg) + + wandb_logger = None + if args.wandb_entity and args.wandb_project and global_rank == 0: + run_name = args.wandb_name or f"sharded_similarity_{Path(args.manifest).stem}" + wandb_logger = WandbLogger( + project=args.wandb_project, entity=args.wandb_entity, name=run_name, log_model=False + ) + logger.info(f"Wandb logging enabled to {args.wandb_entity}/{args.wandb_project}, run name: {run_name}") + trainer.logger = wandb_logger + + dataset = SpeakerShardedAudioDataset( + assigned_records_list=assigned_records_for_this_rank, + base_audio_dir=args.audio_base_dir, + sample_rate=sample_rate, + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + shuffle=False, + collate_fn=dataset.collate_fn, + pin_memory=True, + ) + + model = EmbeddingSimilarityExtractorSharded( + output_dir=args.output_dir, + output_file_prefix=Path(args.manifest).stem, + flush_threshold_items=args.flush_threshold_items, + context_min_duration=args.context_min_duration, + context_min_ssim=args.context_min_ssim, + speaker_expected_counts_map=my_speaker_expected_counts, + initial_assigned_count=len(assigned_records_for_this_rank), + ) + logger.info( + f"Starting prediction with {len(assigned_records_for_this_rank)} records ({len(my_speaker_expected_counts)} unique speakers for this rank according to counts)." 
+ ) + trainer.predict(model, dataloaders=dataloader) + + # Rank 0 merges the partial manifests + if global_rank == 0: + final_manifest_path = Path(args.output_dir) / ( + Path(args.manifest).stem + + f"_withContextAudioMinDur{args.context_min_duration}MinSSIM{args.context_min_ssim}.json" + ) + logger.info(f"Merging partial manifest files to `{final_manifest_path}`...") + with open(final_manifest_path, "w", encoding="utf-8") as final_out_f: + for i in range(world_size): + partial_file_path = Path(args.output_dir) / f"{Path(args.manifest).stem}_rank{i}.json" + if partial_file_path.exists(): + with open(partial_file_path, "r", encoding="utf-8") as pf: + for line in pf: + final_out_f.write(line) + logger.info(f"Merged `{partial_file_path}`") + else: + logger.warning(f"Warning - partial manifest file not found: `{partial_file_path}`") + logger.info(f"Merging complete. Final manifest: `{final_manifest_path}`") + + if wandb_logger and global_rank == 0: + wandb.finish() + logger.info("WandB run finished.") + + logger.info(f"Done.") + + +if __name__ == "__main__": + main() From 77753ed3ccb3567fa1f72b8dd1d017c83387f8ee Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Wed, 18 Jun 2025 07:39:46 -0700 Subject: [PATCH 050/113] Bugfix: handle inference of models that don't have sample_rate in the config (#13955) Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 59ae5c0e4f09..c9fee4f982da 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -77,8 +77,13 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False): else: model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 1 model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 2 - - return model_cfg + if hasattr(model_cfg, 'sample_rate'): + # This was removed from the config and is now in the model class + sample_rate = model_cfg.sample_rate + del model_cfg.sample_rate + else: + sample_rate = None + return model_cfg, sample_rate def update_ckpt(state_dict): new_state_dict = {} @@ -127,7 +132,7 @@ def run_inference( model_cfg = model_cfg.cfg with open_dict(model_cfg): - model_cfg = update_config(model_cfg, codecmodel_path, legacy_codebooks) + model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks) model = MagpieTTSModel(cfg=model_cfg) model.use_kv_cache_for_inference = True @@ -141,13 +146,16 @@ def run_inference( elif nemo_file is not None: model_cfg = MagpieTTSModel.restore_from(nemo_file, return_config=True) with open_dict(model_cfg): - model_cfg = update_config(model_cfg, codecmodel_path, legacy_codebooks) + model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks) model = MagpieTTSModel.restore_from(nemo_file, override_config_path=model_cfg) model.use_kv_cache_for_inference = True checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] else: raise ValueError("Need a checkpoint") + if cfg_sample_rate is not None and cfg_sample_rate != model.sample_rate: + raise ValueError("Sample rate in config and model do not match") + print("Loaded weights.") model.cuda() model.eval() @@ -192,7 +200,7 @@ def run_inference( context_durration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. 
test_dataset = MagpieTTSDataset( dataset_meta=dataset_meta, - sample_rate=model_cfg.sample_rate, + sample_rate=model.sample_rate, min_duration=0.5, max_duration=20, codec_model_samples_per_frame=model.codec_model_samples_per_frame, @@ -268,7 +276,7 @@ def run_inference( predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] audio_path = os.path.join(pred_audio_dir, f"predicted_audio_{item_idx}.wav") - sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate) + sf.write(audio_path, predicted_audio_np, model.sample_rate) codes_path = os.path.join(pred_audio_dir, f"predicted_codes_{item_idx}.pt") torch.save(predicted_codes[idx][:predicted_codes_lens[idx]], codes_path) codec_file_paths.append(codes_path) @@ -332,6 +340,7 @@ def run_inference( with open(all_experiment_csv_with_ci, "a") as f: f.write(f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']},{metrics_mean_ci['frechet_codec_distance']}\n") print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") + measurements = [m['ssim_pred_context_avg'] for m in metrics_n_repeated] ssim = np.mean(measurements) @@ -342,7 +351,6 @@ def run_inference( shutil.rmtree(out_dir) return cer, ssim - def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") From 1bccf0ca391630a0042c6144445707939fb6d64d Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 18 Jun 2025 11:42:48 -0700 Subject: [PATCH 051/113] use learnable position embeddings in cas encoder (#13909) * use learnable position embeddings in cas encoder Signed-off-by: Shehzeen Hussain * infer eval change Signed-off-by: Shehzeen Hussain --------- Signed-off-by: Shehzeen Hussain --- nemo/collections/tts/modules/magpietts_modules.py | 3 ++- scripts/magpietts/evalset_config.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 9751b51508f7..eda50b5bfd1e 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -126,7 +126,8 @@ def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: d_ffn=d_embed * 4, sa_n_heads=8, kernel_size=1, - max_length_causal_mask=256 + max_length_causal_mask=256, + use_learnable_pos_emb=True ) @property diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index c7ad23724aaa..0b8c8c890126 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -27,6 +27,13 @@ 'audio_dir' : '/Data/RivaData/riva', 'feature_dir' : '/Data/RivaData/riva', }, + 'rough_qwen': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/rough.json', + 
'audio_dir' : '/Data/RivaData/riva', + 'feature_dir' : '/Data/RivaData/riva', + 'tokenizer_names': ['qwen'], + 'load_cached_codes_if_available': False, + }, 'riva_challenging_nozeros': { # 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/riva_challenging_nozeros.json', 'manifest_path': '/home/pneekhara/2023/SimpleT5NeMo/manifests/riva_challenging_filtered.json', From 09141850f09cb91d863d852d4f76ee72feebdfe4 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 18 Jun 2025 12:44:04 -0700 Subject: [PATCH 052/113] Magpietts Multilingual IPA GRPO (#13595) * yaml update Signed-off-by: Shehzeen Hussain * update preference optimization codec params Signed-off-by: Shehzeen Hussain * update raw texts for ipa GRPO Signed-off-by: Shehzeen Hussain * bug fix in transcribing non english languages using whisper Signed-off-by: Shehzeen Hussain * added dr grpo loss function and eval updates Signed-off-by: Shehzeen Hussain * log reward std Signed-off-by: Shehzeen Hussain * some datasets added to evalset Signed-off-by: Shehzeen Hussain * added config to scale the reward independent of dr grpo Signed-off-by: Shehzeen Hussain * exclude more models in load state dict Signed-off-by: Shehzeen Hussain * readme update Signed-off-by: Shehzeen Hussain * more models in ignore Signed-off-by: Shehzeen Hussain * bug fix - loss mask shape had changed from mask git code. so handled it in PO Signed-off-by: Shehzeen Hussain * clamp cer/wer between 0 and 1 during GRPO Signed-off-by: Shehzeen Hussain * Update nemo/collections/tts/models/magpietts_preference_optimization.py Co-authored-by: Jason Signed-off-by: Shehzeen Hussain * Update nemo/collections/tts/models/magpietts_preference_optimization.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Shehzeen Hussain * sample rate bug fix Signed-off-by: Shehzeen Hussain * sample rate key Signed-off-by: Shehzeen Hussain * added GRPO test compatible with old multiencoder checkpoint Signed-off-by: Shehzeen Hussain * added comments about raw text Signed-off-by: Shehzeen Hussain * added test in workflow Signed-off-by: Shehzeen Hussain * fix test script - handle = sign in checkpoint name Signed-off-by: Shehzeen Hussain * add more detailed documentation for GRPO Signed-off-by: Shehzeen Hussain * try whisper, since parakeet seems to fail loading in CI CD Signed-off-by: Shehzeen Hussain * remove pesq from test, torchaudio is not supported Signed-off-by: Shehzeen Hussain * change train val jsons for online po test Signed-off-by: Shehzeen Hussain --------- Signed-off-by: Shehzeen Hussain Signed-off-by: Shehzeen Hussain Co-authored-by: Jason Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .github/workflows/cicd-main.yml | 11 +- .../magpietts/magpietts_multilingual_v1.yaml | 31 ++-- .../tts/data/text_to_speech_dataset.py | 9 +- nemo/collections/tts/models/magpietts.py | 22 ++- .../magpietts_preference_optimization.py | 87 +++++++---- scripts/magpietts/README_magpie_po.md | 135 ++++++++++++++---- scripts/magpietts/evalset_config.py | 30 ++++ scripts/magpietts/evaluate_generated_audio.py | 6 +- ...L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh | 90 ++++++++++++ 9 files changed, 343 insertions(+), 78 deletions(-) create mode 100644 tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 120a15e12bac..b3e7bc53e3bd 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -1180,6 +1180,15 @@ jobs: 
SCRIPT: |- RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh + L2_TTS_Fast_dev_runs_Magpietts_OnlinePO: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_TTS_Fast_dev_runs_Magpietts_OnlinePO') + with: + RUNNER: self-hosted-azure + SCRIPT: |- + RUN_ID=${{ github.run_id }} bash tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh + L2_TTS_InferEvaluate_Magpietts_ZeroShot: needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml @@ -2161,4 +2170,4 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} verbose: true - flags: ${{ matrix.flag }} + flags: ${{ matrix.flag }} \ No newline at end of file diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index bec174f81146..df8dd230b724 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -14,7 +14,7 @@ val_ds_meta: ??? model: model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. + use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts context_duration_min: 3.0 @@ -25,12 +25,12 @@ model: prior_end_step: 12000 prior_scaledown_start_step: 8000 indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. - alignment_loss_scale: 0.0 + alignment_loss_scale: 0.002 embedding_dim: 768 codecmodel_path: ??? 
max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - + cfg_unconditional_prob: 0.1 # Alignment encoder parameters, to binarize the prior # This is used for attention-constrained training and inference use_alignment_encoder: false @@ -123,24 +123,26 @@ model: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${train_ds_meta} weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - min_duration: 0.5 + min_duration: 0.2 max_duration: 20.0 dataloader_params: batch_size: ${batch_size} num_workers: 4 drop_last: true + pin_memory: true validation_ds: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${val_ds_meta} - min_duration: 0.5 + min_duration: 0.2 max_duration: 20.0 dataloader_params: batch_size: ${batch_size} - num_workers: 0 + num_workers: 4 + pin_memory: true encoder: n_layers: 6 @@ -151,7 +153,7 @@ model: p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: false - is_causal: False + is_causal: true apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true @@ -175,12 +177,12 @@ model: d_model: 768 d_ffn: 3072 sa_n_heads: 12 - kernel_size: 3 + kernel_size: 1 p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: true xa_d_memory: 768 - xa_n_heads: 12 + xa_n_heads: 1 is_causal: true apply_norm_to_cond: true apply_norm_out: true @@ -188,9 +190,8 @@ model: use_learnable_pos_emb: true optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] + _target_: torch.optim.AdamW + lr: 1e-4 sched: name: ExponentialLR @@ -201,13 +202,14 @@ trainer: devices: -1 accelerator: gpu strategy: ddp_find_unused_parameters_true - precision: 32 + precision: bf16-mixed max_epochs: ${max_epochs} accumulate_grad_batches: 1 enable_checkpointing: False # Provided by exp_manager logger: false # Provided by exp_manager log_every_n_steps: 100 - val_check_interval: 500 + check_val_every_n_epoch: 1 + num_sanity_val_steps: 0 benchmark: false gradient_clip_val: 2.5 @@ -219,6 +221,7 @@ exp_manager: wandb_logger_kwargs: name: null project: null + resume: true create_checkpoint_callback: true checkpoint_callback_params: monitor: val_loss diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 00742047537f..47083c2b538d 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -606,7 +606,14 @@ def __getitem__(self, index): align_prior = torch.tensor(align_prior, dtype=torch.float32) example["align_prior"] = align_prior - example['raw_text'] = data.text + if "original_text" in data.manifest_entry: + # Raw Text is used as the GT for CER/WER computation in DPO pref data generation + # and GRPO reward setup. For manifests in which the 'text' field is phonemized, + # we use the 'original_text' field as the raw text. Otherwise, we use the regular text field. 
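+ # Hypothetical manifest entry for illustration only (field values are made up, not taken from any dataset):
+ # {"audio_filepath": "a.wav", "text": "h@l'oU w'3rld", "original_text": "Hello world", ...}
+ # Scoring ASR transcripts against the phonemized 'text' would inflate CER/WER, hence the fallback below.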
+ example['raw_text'] = data.manifest_entry['original_text'] + else: + example['raw_text'] = data.text + example['language'] = data.manifest_entry.get('language', 'en') if "reward" in data.manifest_entry: diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 9b650ee4207b..fbfdfb2a5013 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -283,6 +283,22 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): del state_dict[key] return state_dict + def update_ckpt(self, state_dict): + """ + Backward compatibility for checkpoints saved with old model names. + """ + new_state_dict = {} + for key in state_dict.keys(): + if 't5_encoder' in key: + new_key = key.replace('t5_encoder', 'encoder') + new_state_dict[new_key] = state_dict[key] + elif 't5_decoder' in key: + new_key = key.replace('t5_decoder', 'decoder') + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[key] = state_dict[key] + return new_state_dict + def load_state_dict(self, state_dict, strict=True): """ Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when @@ -290,10 +306,14 @@ def load_state_dict(self, state_dict, strict=True): When strict is False, we can call pytorch's load_state_dict. When strict is True, we loop through all parameters and rename them to enable loading. """ + state_dict = self.update_ckpt(state_dict) if strict == False: super().load_state_dict(state_dict, strict=False) for name, child in self.named_children(): - if name in ['_speaker_verification_model', '_codec_model']: + if name in ['_speaker_verification_model', '_codec_model', '_reference_model', + 'eval_asr_model', 'eval_speaker_verification_model', + 'whisper_model', 'squim_objective_model' + ]: continue if any(param.numel() > 0 for param in child.parameters()): # If the module has parameters, we want to change the default mapping so that the state_dict gets diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index f778da20f8d0..d29cf5b1c48c 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -88,8 +88,8 @@ def test_step(self, batch, batch_idx): if not os.path.exists(audio_dir): os.makedirs(audio_dir) audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') - audio_durations.append(len(predicted_audio_np) / self.cfg.sample_rate) - sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) + audio_durations.append(len(predicted_audio_np) / self.sample_rate) + sf.write(audio_path, predicted_audio_np, self.sample_rate) predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] @@ -280,9 +280,9 @@ def process_batch_dpo(self, batch_chosen_rejected): rejected_policy_logprobs = None chosen_ref_logprobs = None rejected_ref_logprobs = None - for codebook_idx in range(self.cfg.num_audio_codebooks): - si = codebook_idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook + for codebook_idx in range(self.num_audio_codebooks): + si = codebook_idx * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook codebook_logits_chosen = model_output_chosen['logits'][:, :, si:ei] codebook_logits_rejected = 
model_output_rejected['logits'][:, :, si:ei] @@ -292,11 +292,11 @@ def process_batch_dpo(self, batch_chosen_rejected): codebook_labels_chosen = model_output_chosen['audio_codes_target'][:,codebook_idx] codebook_labels_rejected = model_output_rejected['audio_codes_target'][:,codebook_idx] - codebook_log_probs_chosen = self._get_batch_logps(codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask']) - codebook_log_probs_rejected = self._get_batch_logps(codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask']) + codebook_log_probs_chosen = self._get_batch_logps(codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask'][:,codebook_idx]) + codebook_log_probs_rejected = self._get_batch_logps(codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask'][:,codebook_idx]) with torch.no_grad(): - ref_codebook_log_probs_chosen = self._get_batch_logps(ref_codebook_logits_chosen, codebook_labels_chosen, reference_model_output_chosen['loss_mask']) - ref_codebook_log_probs_rejected = self._get_batch_logps(ref_codebook_logits_rejected, codebook_labels_rejected, reference_model_output_rejected['loss_mask']) + ref_codebook_log_probs_chosen = self._get_batch_logps(ref_codebook_logits_chosen, codebook_labels_chosen, reference_model_output_chosen['loss_mask'][:,codebook_idx]) + ref_codebook_log_probs_rejected = self._get_batch_logps(ref_codebook_logits_rejected, codebook_labels_rejected, reference_model_output_rejected['loss_mask'][:,codebook_idx]) if chosen_policy_logprobs is None: chosen_policy_logprobs = codebook_log_probs_chosen @@ -435,6 +435,12 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # import ipdb; ipdb.set_trace() assert HAVE_TORCHAUDIO, "torchaudio is required for PESQ reward" self.squim_objective_model = SQUIM_OBJECTIVE.get_model() + + self.loss_type = self.cfg.get('loss_type', 'grpo') + if self.loss_type not in ['grpo', 'dr_grpo']: + raise ValueError(f"Received loss_type of {self.loss_type}, but the model only accepts one of ['grpo', 'dr_grpo']") + self.scale_rewards = self.cfg.get('scale_rewards', True) + self.max_decoder_steps = self.cfg.get('max_decoder_steps', 430) def state_dict(self, destination=None, prefix='', keep_vars=False): state_dict = super().state_dict(destination, prefix, keep_vars) @@ -444,8 +450,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): if any([substring in key for substring in keys_substrings_to_exclude]): del state_dict[key] return state_dict - - + def _get_per_token_logps(self, logits, labels, loss_mask): """Compute the log probabilities of the given labels under the given logits. 
@@ -486,7 +491,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): print("use_cfg", use_cfg) predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch_repeated, - max_decoder_steps=self.cfg.get('max_decoder_steps', 430), + max_decoder_steps=self.max_decoder_steps, temperature=temperature, topk=topk, use_cfg=use_cfg, @@ -506,8 +511,8 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): audio_dir = os.path.join(log_dir, 'audios') os.makedirs(audio_dir, exist_ok=True) audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') - audio_durations.append(len(predicted_audio_np) / self.cfg.sample_rate) - sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) + audio_durations.append(len(predicted_audio_np) / self.sample_rate) + sf.write(audio_path, predicted_audio_np, self.sample_rate) predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] # C, T @@ -541,6 +546,8 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): gt_transcript = process_text_for_cer(batch_repeated['raw_texts'][idx]) cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True) wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False) + cer_gt = min(max(cer_gt, 0.0), 1.0) # Ensure CER is in [0, 1] + wer_gt = min(max(wer_gt, 0.0), 1.0) # Ensure WER is in [0, 1] spk_embedding_pred = pred_speaker_embeddings[idx].cpu().float().numpy() spk_embedding_gt = gt_speaker_embeddings[idx].cpu().float().numpy() spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( @@ -574,6 +581,8 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): best_ssim_achievable = self.cfg.get("best_ssim_achievable", 0.9) # Examples with this speaker similarity or higher will have SSIM reward of 1 mean_cer_dataset = self.cfg.get("mean_cer_dataset", 0.1) # CER equal to this value will have reward of 0.5 mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6) # SSIM equal to this value will have reward of 0.5 + all_groups_mean_reward = 0.0 + all_groups_std_reward = 0.0 for group_idx in range(num_groups): group_start_idx = group_idx * num_generations_per_item group_end_idx = group_start_idx + num_generations_per_item @@ -615,15 +624,21 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): mean_reward /= num_generations_per_item std_reward = np.std(group_rewards) + all_groups_mean_reward += mean_reward + all_groups_std_reward += std_reward for idx in range(group_start_idx, group_end_idx): - batch_metrics[idx]['advantage'] = (batch_metrics[idx]['reward'] - mean_reward) / (std_reward + 1e-6) - + batch_metrics[idx]['advantage'] = batch_metrics[idx]['reward'] - mean_reward + if self.scale_rewards: + batch_metrics[idx]['advantage'] = batch_metrics[idx]['advantage'] / (std_reward + 1e-4) + all_groups_mean_reward = all_groups_mean_reward / num_groups + all_groups_std_reward = all_groups_std_reward / num_groups advantages = [x['advantage'] for x in batch_metrics] advantages = torch.tensor(advantages, device=self.device) - print("Mean reward: ", mean_reward) + print("Mean reward: ", all_groups_mean_reward) return { - 'mean_reward': torch.tensor(mean_reward, device=self.device), + 'mean_reward': torch.tensor(all_groups_mean_reward, device=self.device), + 'std_reward': torch.tensor(all_groups_std_reward, 
device=self.device), 'batch_repeated': batch_repeated, 'metrics': batch_metrics, 'predicted_codes': predicted_codes, @@ -671,27 +686,37 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): total_loss = None total_kl = None - for codebook_idx in range(self.cfg.num_audio_codebooks): - si = codebook_idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook + for codebook_idx in range(self.num_audio_codebooks): + policy_codebook_loss_mask = policy_model_outputs['loss_mask'][:,codebook_idx,:] + reference_codebook_loss_mask = reference_model_output['loss_mask'][:,codebook_idx,:] if not self.reference_free else None + si = codebook_idx * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook codebook_logits = policy_model_outputs['logits'][:, :, si:ei] # B, T, C codebook_labels = batch_repeated['audio_codes'][:,codebook_idx,1:] - per_token_codebook_log_probs = self._get_per_token_logps(codebook_logits, codebook_labels, policy_model_outputs['loss_mask']) + + per_token_codebook_log_probs = self._get_per_token_logps(codebook_logits, codebook_labels, policy_codebook_loss_mask) per_token_loss = -(torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach()) * advantages.unsqueeze(1)) if not self.reference_free: with torch.no_grad(): ref_codebook_logits = reference_model_output['logits'][:, :, si:ei] - per_token_ref_codebook_log_probs = self._get_per_token_logps(ref_codebook_logits, codebook_labels, reference_model_output['loss_mask']) + per_token_ref_codebook_log_probs = self._get_per_token_logps(ref_codebook_logits, codebook_labels, reference_codebook_loss_mask) # https://github.com/huggingface/trl/blob/ffcb9f4aee725a2bd072d0387afe68a4b1c7967c/trl/trainer/grpo_trainer.py#L703 per_token_codebook_kl = torch.exp(per_token_ref_codebook_log_probs - per_token_codebook_log_probs) - (per_token_ref_codebook_log_probs - per_token_codebook_log_probs) - 1 per_token_loss = per_token_loss + self.cfg.grpo_beta * per_token_codebook_kl - codebook_kl_loss_mean = ((per_token_codebook_kl * policy_model_outputs['loss_mask']).sum(dim=1) / policy_model_outputs['loss_mask'].sum(dim=1)).mean() + codebook_kl_loss_mean = ((per_token_codebook_kl * policy_codebook_loss_mask).sum(dim=1) / policy_codebook_loss_mask.sum(dim=1)).mean() else: codebook_kl_loss_mean = torch.tensor(0.0, device=self.device) - codebook_loss = ((per_token_loss * policy_model_outputs['loss_mask']).sum(dim=1) / policy_model_outputs['loss_mask'].sum(dim=1)).mean() + if self.loss_type == "grpo": + codebook_loss = ((per_token_loss * policy_codebook_loss_mask).sum(dim=1) / policy_codebook_loss_mask.sum(dim=1)).mean() + elif self.loss_type == "dr_grpo": + # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py + total_tokens = per_token_loss.shape[0] * self.max_decoder_steps + codebook_loss = (per_token_loss * policy_codebook_loss_mask).sum() / total_tokens + else: + raise ValueError(f"Unknown loss function: {self.loss_type}") if total_loss is None: total_loss = codebook_loss @@ -700,11 +725,11 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): total_loss += codebook_loss total_kl += codebook_kl_loss_mean - - total_loss /= self.cfg.num_audio_codebooks - print("Total kl", total_kl, n_generations_per_item) + total_loss /= self.num_audio_codebooks + return { 'mean_reward': generated_codes_and_metrics['mean_reward'], + 'std_reward': generated_codes_and_metrics['std_reward'], 'loss': total_loss, 'kl_loss': 
total_kl,
+ 'batch_metrics': generated_codes_and_metrics['metrics'],
@@ -717,6 +742,7 @@ def training_step(self, batch, batch_idx):
self.log('train_loss', po_outputs['loss'], prog_bar=True, sync_dist=True)
self.log('train_kl_loss', po_outputs['kl_loss'], prog_bar=True, sync_dist=True)
self.log('train_mean_reward', po_outputs['mean_reward'], prog_bar=True, sync_dist=True)
+ self.log('train_std_reward', po_outputs['std_reward'], prog_bar=True, sync_dist=True)
return po_outputs['loss']

def validation_step(self, batch, batch_idx):
@@ -728,6 +754,7 @@ def validation_step(self, batch, batch_idx):
self.validation_step_outputs.append({
'mean_reward': mean_reward,
+ 'std_reward': po_outputs['std_reward'],
'val_loss': val_loss,
'val_kl_loss': val_kl_loss,
'batch_metrics': batch_metrics,
@@ -747,10 +774,12 @@ def collect(key):
val_loss = collect("val_loss")
val_kl_loss = collect("val_kl_loss")
mean_reward = collect("mean_reward")
+ std_reward = collect("std_reward")

self.log("val_loss", val_loss, prog_bar=True, sync_dist=True)
self.log("val_kl_loss", val_kl_loss, prog_bar=True, sync_dist=True)
self.log("val_mean_reward", mean_reward, prog_bar=True, sync_dist=True)
+ self.log("val_std_reward", std_reward, prog_bar=True, sync_dist=True)

mean_metrics = {}
for val_output in self.validation_step_outputs:
@@ -822,7 +851,7 @@ def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model,

def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper_model, device):
speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000)
- forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language) if language else None
+ forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None
inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features
inputs = inputs.to(device)
with torch.no_grad():
diff --git a/scripts/magpietts/README_magpie_po.md b/scripts/magpietts/README_magpie_po.md
index d6d243e81899..d2dab048e1ec 100644
--- a/scripts/magpietts/README_magpie_po.md
+++ b/scripts/magpietts/README_magpie_po.md
@@ -134,45 +134,120 @@ python scripts/magpietts/dpo/create_text_contextpairs.py \

2. Train using GRPO

+To train with GRPO, we use a training command similar to the base model training command, with a few modifications.
+
+1. We start from a pretrained checkpoint supplied using `+init_from_ptl_ckpt`.
+2. We add `+mode="onlinepo_train"` to specify preference-optimization-based training.
+3. Use a small batch size (bs=2) since we generate `num_generations_per_item` samples per item in the batch and the effective batch size becomes `bs*num_generations_per_item`.
+4. The manifest should contain absolute audio paths, and `audio_dir` is specified as "/" in the `train_ds_meta` command.
+5. Use the same model-specific overrides as the base model (e.g. x-attn heads, is_causal, num_layers, local transformer, etc.).
+6. Set dropout probabilities to 0 for all modules. This is especially important if we are not using reference-free mode, since the KL divergence loss otherwise becomes very spiky and unstable. Set the probability to 0 via `model.decoder.p_dropout=0.0`.
+7. Don't use the attention prior or the CTC loss during GRPO.
+8. Add the following GRPO-specific arguments to the training command.
+
+```
++model.grpo_beta=0.0 \ # Coefficient for the KL loss (if not using reference-free mode)
++model.num_generations_per_item=12 \ # 12 samples are generated for each item and a reward is computed for each
++model.reference_free=true \ # Reference-free means we don't use the KL loss term and only optimize for rewards
++model.inference_cfg_prob=0.0 \ # Fraction of generations produced with CFG. Set > 0.0 to optimize for both CFG and non-CFG modes of generation
++model.inference_cfg_scale=2.5 \ # CFG scale for samples generated using CFG
++model.cer_reward_weight=0.33 \ # Weight of the CER reward in the overall reward
++model.ssim_reward_weight=0.33 \ # Weight of the SSIM reward in the overall reward
++model.pesq_reward_weight=0.33 \ # Weight of the PESQ reward in the overall reward
++model.use_pesq=true \ # Set this to true if using the PESQ reward
++model.reward_asr_model="whisper" \ # Use whisper only for multilingual settings; don't specify it for English
+model.cfg_unconditional_prob=0.0 \ # Set this to 0; we don't want to drop the unconditional input
++model.inference_topk=2016 \ # Top-K. Not yet sure whether topk=80 should be used; topk=2016 effectively disables top-k
++model.inference_temperature=0.8 \ # Slightly higher temperature for more variety of generations in preference optimization
++model.use_kv_cache_during_online_po=true \ # Use KV caching while generating samples for GRPO
++model.loss_type="grpo" \ # Can be grpo or dr_grpo; grpo works better in my experiments
++model.scale_rewards=true \ # Whether to divide advantages by the std deviation (set true for GRPO and false for DR_GRPO)
++model.max_decoder_steps=430 \ # Max steps for generation
+```
+
+9. We also want to validate more frequently during GRPO since each step takes longer, so we add the following args:
+```
+~trainer.check_val_every_n_epoch \
++trainer.val_check_interval=50 \
+```
+
+10. We use a lower learning rate and save the best checkpoints based on the lowest CER on our validation set using:
+```
+model.optim.lr=1e-7 \
+~model.optim.sched \
+exp_manager.checkpoint_callback_params.monitor="val_cer_gt" \
+exp_manager.checkpoint_callback_params.mode="min" \
+```
+
+11.
Specify precision and gradient clipping as necessary +``` +trainer.precision=32 \ ++trainer.gradient_clip_val=2.5 \ +``` + + +Below is a sample training command for multilingual GRPO: + ``` python examples/tts/magpietts.py \ +--config-name=magpietts_multilingual_v1 \ +batch_size=2 \ ++init_from_ptl_ckpt="/mountdir/checkpoints/magpie_checkpoints/shared_char_ipa_epoch285.ckpt" \ +mode="onlinepo_train" \ -+init_from_ptl_ckpt="/Data/ICML2025_CKPTS/icml2025_base_checkpoints/decodercontext_small_sp_ks3CorrectWithPrior_onlyphoneme_epoch161.ckpt" \ -max_epochs=1000 \ -exp_manager.exp_dir="/Data/Experiments/NewT5TTSGRPO/Try3NoDropoutBeta0.01_CFG/" \ -+train_ds_meta.grpotrainnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_train_nomls.json" \ -+train_ds_meta.grpotrainnomls.audio_dir="/" \ -+train_ds_meta.grpotrainnomls.feature_dir="/" \ -+val_ds_meta.grpovalnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_val_unseenspeakers_tinysubset.json" \ -+val_ds_meta.grpovalnomls.audio_dir="/" \ -+val_ds_meta.grpovalnomls.feature_dir="/" \ -+model.num_generations_per_item=6 \ -+model.grpo_beta=0.01 \ +~model.text_tokenizers.multilingual_sentencepiece \ ++model.text_tokenizers.chartokenizer._target_=AutoTokenizer \ ++model.text_tokenizers.chartokenizer.pretrained_model="google/byt5-small" \ +max_epochs=20 \ +exp_manager.exp_dir="${DOCKER_EXP_DIR}" \ ++exp_manager.version=0 \ +exp_manager.checkpoint_callback_params.always_save_nemo=false \ ++train_ds_meta.dpopreftrain.manifest_path="/data/TTS/CML/manifests_with_codecs_ipa3/cml_tts_dataset_portuguese_v0.1/train_withAudioCodes_codec21KhzCausalDecoder_filtered_textcontextpairs_train_GRPO_ipa_NoDuplicates.json" \ ++train_ds_meta.dpopreftrain.audio_dir="/" \ ++train_ds_meta.dpopreftrain.feature_dir="/" \ ++train_ds_meta.dpopreftrain.tokenizer_names="[chartokenizer]" \ ++val_ds_meta.dpoprefval.manifest_path="/data/TTS/CML/manifests_with_codecs_ipa3/cml_tts_dataset_portuguese_v0.1/train_withAudioCodes_codec21KhzCausalDecoder_filtered_textcontextpairs_val_GRPO_ipa.json" \ ++val_ds_meta.dpoprefval.audio_dir="/" \ ++val_ds_meta.dpoprefval.feature_dir="/" \ ++val_ds_meta.dpoprefval.tokenizer_names="[chartokenizer]" \ ++model.grpo_beta=0.0 \ ++model.num_generations_per_item=12 \ +model.reference_free=true \ -model.decoder.p_dropout=0.0 \ -model.encoder.p_dropout=0.0 \ ++model.inference_cfg_prob=0.0 \ ++model.inference_cfg_scale=2.5 \ ++model.cer_reward_weight=0.33 \ ++model.ssim_reward_weight=0.33 \ ++model.pesq_reward_weight=0.33 \ ++model.use_pesq=true \ ++model.reward_asr_model="whisper" \ +model.cfg_unconditional_prob=0.0 \ ++model.inference_topk=2016 \ ++model.inference_temperature=0.8 \ ++model.use_kv_cache_during_online_po=true \ ++model.loss_type="grpo" \ ++model.scale_rewards=true \ ++model.max_decoder_steps=430 \ model.model_type="decoder_context_tts" \ -model.use_text_conditioning_encoder=true \ model.context_duration_min=5.0 \ model.context_duration_max=5.0 \ -model.codecmodel_path="/Data/Checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.decoder.p_dropout=0.0 \ +model.encoder.p_dropout=0.0 \ +model.local_transformer_type="autoregressive" \ +model.local_transformer_n_layers=1 \ +model.local_transformer_n_heads=1 \ +model.local_transformer_hidden_dim=256 \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path="/mountdir/checkpoints/21fps_causal_codecmodel.nemo" \ model.alignment_loss_scale=0.0 \ model.prior_scaling_factor=null \ -model.train_ds.dataloader_params.num_workers=0 \ 
-model.validation_ds.dataloader_params.num_workers=0 \ -exp_manager.checkpoint_callback_params.monitor="val_mean_reward" \ -exp_manager.checkpoint_callback_params.mode="max" \ -+trainer.use_distributed_sampler=False \ -+model.inference_cfg_prob=0.5 \ -+model.inference_cfg_scale=2.5 \ -batch_size=2 \ -model.optim.lr=1e-6 \ -trainer.devices=2 \ -trainer.log_every_n_steps=1 \ -trainer.val_check_interval=50 \ +~trainer.check_val_every_n_epoch \ ++trainer.val_check_interval=50 \ +trainer.log_every_n_steps=10 \ +model.optim.lr=1e-7 \ ~model.optim.sched \ -trainer.num_nodes=${SLURM_JOB_NUM_NODES} ; +exp_manager.checkpoint_callback_params.monitor="val_cer_gt" \ +exp_manager.checkpoint_callback_params.mode="min" \ +trainer.precision=32 \ ++trainer.gradient_clip_val=2.5 \ +trainer.num_nodes=${SLURM_JOB_NUM_NODES} ``` -Note that setting `+model.reference_free=true` makes the `grpo_beta` param effectively 0 since it does not use the KL regularization loss and saves memory. If using the `grpo_beta > 0` and `+model.reference_free=false`, make sure to set dropout params to 0, `model.decoder.p_dropout=0.0` and -`model.encoder.p_dropout=0.0` for training stabilization. Recommended learning rate is `model.optim.lr=1e-6` or lower. Setting `+model.inference_cfg_prob=0.5` means that for half of the generations will be generated using cfg, so that we optimize for our preferences in both cfg and non cfg inference modes. You may set `+model.inference_cfg_prob=0.0` if we only care about non-cfg inference. \ No newline at end of file diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 0b8c8c890126..a7b88d8eb5cd 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -105,6 +105,12 @@ 'feature_dir' : '/Data/LibriTTS', 'tokenizer_names': ['chartokenizer'], }, + 'libri_unseen_test_shehzeen_shared_char_ipa': { + 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests_ipa/test_clean_withContextAudioPaths_ipa.json', + 'audio_dir' : '/Data/LibriTTS', + 'feature_dir' : '/Data/LibriTTS', + 'tokenizer_names': ['chartokenizer'], + }, 'grpo_valset': { 'manifest_path' : '/Data/DPOPairsInputDatav2/text_context_pairs_grpo_val_unseenspeakers.json', 'audio_dir' : '/', @@ -276,6 +282,14 @@ 'whisper_language': 'pt', 'load_cached_codes_if_available': False }, + 'portuguese_cml_shared_char_ipa': { + 'manifest_path' : '/Data/CML/manifests_with_codecs_ipa3/cml_tts_dataset_portuguese_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset_ipa.json', + 'audio_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', + 'feature_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', + 'tokenizer_names': ['chartokenizer'], + 'whisper_language': 'pt', + 'load_cached_codes_if_available': False + }, 'polish_cml_sep_char': { 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_polish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', 'audio_dir': '/Data/CML/cml_tts_dataset_polish_v0.1', @@ -289,6 +303,22 @@ 'audio_dir' : '/mnt/drive1/data/LibriTTS/', 'feature_dir' : None, }, + 'hindi_indic_shared_char': { + 'manifest_path' : '/Data/IndicDataset/manifests_ipa/hindi_100_test.json', + 'audio_dir': '/', + 'feature_dir': '/', + 'tokenizer_names': ['chartokenizer'], + 'whisper_language': 'hi', + 'load_cached_codes_if_available': False + }, + 'bengali_indic_shared_char': { + 'manifest_path' : '/Data/IndicDataset/manifests_ipa/bengali_100_test.json', + 'audio_dir': '/', + 'feature_dir': '/', + 'tokenizer_names': ['chartokenizer'], + 
'whisper_language': 'bn', + 'load_cached_codes_if_available': False + }, 'an4_val_ci': { 'manifest_path' : '/home/TestData/an4_dataset/an4_val_context_v1.json', 'audio_dir' : '/', diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 720705bf4abc..237ad33e1c0f 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -73,7 +73,7 @@ def process_text(input_text): def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, language, device): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) # Set the language task (optional, improves performance for specific languages) - forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language) if language else None + forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features inputs = inputs.to(device) # Generate transcription @@ -186,7 +186,9 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo pred_text = "" gt_audio_text = "" - if 'normalized_text' in record: + if "original_text" in record: + gt_text = process_text(record['original_text']) + elif 'normalized_text' in record: gt_text = process_text(record['normalized_text']) else: gt_text = process_text(record['text']) diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh new file mode 100644 index 000000000000..a1d3ecd4b405 --- /dev/null +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh @@ -0,0 +1,90 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ + --config-name magpietts_multilingual_v1 \ + +mode="onlinepo_train" \ + ~model.text_tokenizers.multilingual_sentencepiece \ + +model.text_tokenizers.english_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.english_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.spanish_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.spanish_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.french_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.french_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.dutch_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.dutch_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.italian_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.italian_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.german_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.german_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.portugese_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.portugese_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.polish_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.polish_chartokenizer.pretrained_model="google/byt5-small" \ + +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train_context_v1.json" \ + +train_ds_meta.an4.audio_dir="/" \ + +train_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +train_ds_meta.an4.feature_dir=null \ + +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \ + +val_ds_meta.an4.audio_dir="/" \ + +val_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +val_ds_meta.an4.feature_dir=null \ + +init_from_ptl_ckpt="/home/TestData/tts/2506_SeenSpeaker/T5TTS--val_loss\=0.3125-epoch\=8.ckpt" \ + max_epochs=1 \ + batch_size=2 \ + +model.grpo_beta=0.0 \ + +model.num_generations_per_item=6 \ + +model.reference_free=true \ + +model.inference_cfg_prob=0.0 \ + +model.inference_cfg_scale=2.5 \ + +model.cer_reward_weight=0.5 \ + +model.ssim_reward_weight=0.5 \ + +model.reward_asr_model="whisper" \ + model.local_transformer_type="none" \ + model.cfg_unconditional_prob=0.0 \ + model.model_type="multi_encoder_context_tts" \ + model.transcript_decoder_layers="[0,2,4,6,8,10]" \ + model.context_decoder_layers="[1,3,5,7,9,11]" \ + model.context_duration_min=3.0 \ + model.context_duration_max=8.0 \ + model.decoder.p_dropout=0.0 \ + model.context_encoder.p_dropout=0.0 \ + model.encoder.p_dropout=0.0 \ + model.decoder.kernel_size=1 \ + model.decoder.xa_n_heads=1 \ + model.context_encoder.n_layers=6 \ + model.encoder.is_causal=false \ + model.use_text_conditioning_encoder=true \ + +model.forced_num_all_tokens_per_codebook=2048 \ + +model.forced_audio_eos_id=2047 \ + +model.forced_audio_bos_id=2046 \ + +model.forced_context_audio_eos_id=2045 \ + +model.forced_context_audio_bos_id=2044 \ + model.codecmodel_path="/home/TestData/tts/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo" \ + model.alignment_loss_scale=0.0 \ + model.prior_scaling_factor=null \ + trainer.log_every_n_steps=10 \ + +model.inference_topk=2016 \ + model.optim.lr=2e-7 \ + ~model.optim.sched \ + +model.use_kv_cache_during_online_po=true \ + exp_manager.checkpoint_callback_params.monitor="val_cer_gt" \ + 
exp_manager.checkpoint_callback_params.mode="min" \ + trainer.precision=32 \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.strategy=auto \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.num_workers=0 \ + ~trainer.check_val_every_n_epoch From c63b63a3e1026e7c8756edef179184134ab7de89 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 24 Jun 2025 11:06:57 -0700 Subject: [PATCH 053/113] inference bugfix: run all datasets (#13967) In `run_inference()` we were returning after just one dataset. Instead move the code that computes SSIM and CER averages outside the dataset loop so that all datasets get tested. We return an average CER and SSIM across datasets. Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index c9fee4f982da..9eb0366ac97c 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -178,7 +178,10 @@ def run_inference( sv_model ) dataset_meta_info = evalset_config.dataset_meta_info + ssim_per_dataset = [] + cer_per_dataset = [] for dataset in datasets: + print(f"Evaluating dataset {dataset}") metrics_n_repeated = [] manifest_records = read_manifest(dataset_meta_info[dataset]['manifest_path']) for repeat_idx in range(num_repeats): @@ -343,13 +346,18 @@ def run_inference( measurements = [m['ssim_pred_context_avg'] for m in metrics_n_repeated] - ssim = np.mean(measurements) + ssim_current = np.mean(measurements) + ssim_per_dataset.append(ssim_current) measurements = [m['cer_cumulative'] for m in metrics_n_repeated] - cer = np.mean(measurements) - - if clean_up_disk: - shutil.rmtree(out_dir) - return cer, ssim + cer_current = np.mean(measurements) + cer_per_dataset.append(cer_current) + + # Average across datasets + ssim = np.mean(ssim_per_dataset) + cer = np.mean(cer_per_dataset) + if clean_up_disk: + shutil.rmtree(out_dir) + return cer, ssim def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') From d2f730a06fdbd8c8fc9dea959a3d0a884caf8783 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Fri, 27 Jun 2025 15:53:13 -0700 Subject: [PATCH 054/113] Interence fix: clean up old files (#14039) Delete leftover generated files (audio and codes) from previous runs if any exist. Leaving them there is risky since the same output directory gets reused when we evaluate the same checkpoint again. It's especially problematic if the previous evaluation was interrupted and left corrupted files. 
Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 9eb0366ac97c..b42efd19ee3b 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -98,6 +98,16 @@ def update_ckpt(state_dict): new_state_dict[key] = state_dict[key] return new_state_dict + +def delete_old_generated_files(output_dir): + # Delete any leftover generated files from previous runs as these can confuse the evaluation + print(f"Deleting old generated files in: {output_dir} ...") + for f in glob.glob(f"{output_dir}/predicted_codes*.pt"): + os.remove(f) + for f in glob.glob(f"{output_dir}/predicted_audio*.wav"): + os.remove(f) + + def run_inference( hparams_file, checkpoint_file, @@ -189,6 +199,7 @@ def run_inference( audio_dir = os.path.join(eval_dir, "audio") pred_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(pred_audio_dir, exist_ok=True) + delete_old_generated_files(pred_audio_dir) language = dataset_meta_info[dataset].get('whisper_language', 'en') dataset_meta_for_dl = copy.deepcopy(dataset_meta_info[dataset]) for key in ["whisper_language", "load_cached_codes_if_available"]: @@ -359,6 +370,7 @@ def run_inference( shutil.rmtree(out_dir) return cer, ssim + def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") From 3d670b7078b6d6fdc1287a0bd7afbbe57d3b89ba Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Fri, 27 Jun 2025 18:41:31 -0700 Subject: [PATCH 055/113] inference: add option to include the experiment name in the output folder name (#14040) * Add an inference option to log the experiment name Signed-off-by: Fejgin, Roy * Add exp_name to checkpoint name rather than to metrics Signed-off-by: Fejgin, Roy * Clarify a help string Signed-off-by: Fejgin, Roy * Clarify help string Signed-off-by: Fejgin, Roy * Cleanup Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index b42efd19ee3b..9b7ae676d917 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -134,6 +134,7 @@ def run_inference( maskgit_n_steps=3, legacy_codebooks=False, clean_up_disk=False, + log_exp_name=False ): # Load model if hparams_file is not None: @@ -170,8 +171,16 @@ def run_inference( model.cuda() model.eval() + if log_exp_name: + # the experiment name is the name of the directory two above the checkpoint path, + # since training produces directories of the form `exp_name/checkpoints/checkpoint_name.ckpt`. 
+ exp_name = f"{os.path.basename(os.path.dirname(os.path.dirname(checkpoint_file)))}__" + else: + exp_name = "" + checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] - checkpoint_name = "{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_{}_{}_start{}_Estlayers{}_PrLayers{}_LT_{}_MGsteps{}_sv_{}".format( + checkpoint_name = "{}{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_LT_{}_MGsteps_{}_ST_{}_sched_{}".format( + exp_name, checkpoint_name, temperature, topk, @@ -187,6 +196,7 @@ def run_inference( maskgit_n_steps, sv_model ) + dataset_meta_info = evalset_config.dataset_meta_info ssim_per_dataset = [] cer_per_dataset = [] @@ -405,6 +415,7 @@ def main(): parser.add_argument('--clean_up_disk', action='store_true') parser.add_argument('--cer_target', type=float, default=None) parser.add_argument('--ssim_target', type=float, default=None) + parser.add_argument('--log_exp_name', action='store_true', help="Include the experiment name (derived from the checkpoint path) in the output folder name.") args = parser.parse_args() estimate_alignment_from_layers = None @@ -446,7 +457,8 @@ def main(): use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks, - clean_up_disk=args.clean_up_disk + clean_up_disk=args.clean_up_disk, + log_exp_name=args.log_exp_name ) return elif (args.nemo_file is not None): @@ -477,7 +489,8 @@ def main(): use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks, - clean_up_disk=args.clean_up_disk + clean_up_disk=args.clean_up_disk, + log_exp_name=args.log_exp_name ) else: BASE_EXP_DIR = args.base_exp_dir @@ -541,7 +554,8 @@ def main(): use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks, - clean_up_disk=args.clean_up_disk + clean_up_disk=args.clean_up_disk, + log_exp_name=args.log_exp_name ) if cer > float(args.cer_target): raise ValueError() From 41107281245f0f7e1a1f000845c55d53ba29beb4 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Fri, 27 Jun 2025 19:19:06 -0700 Subject: [PATCH 056/113] Bugfix in saving predicted codes for FCD calculation (#14051) The wrong dimension was being trimmed to valid the frame count. 
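For context, here is a toy illustration of why slicing the first dimension trims the wrong axis (assumed [codebooks, frames] layout and made-up sizes, not the actual tensors in the script):

```
import torch

codes = torch.zeros(8, 100)      # one item's codes: [C=8 codebooks, T=100 frames]
n_valid = 40                     # number of valid frames for this item

wrong = codes[:n_valid]          # slices dim 0 (codebooks); frames are never trimmed -> shape [8, 100]
right = codes[:, :n_valid]       # slices dim 1 (frames) as intended -> shape [8, 40]
print(wrong.shape, right.shape)  # torch.Size([8, 100]) torch.Size([8, 40])
```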
Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 9b7ae676d917..a9a7fa43496b 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -302,7 +302,8 @@ def run_inference( audio_path = os.path.join(pred_audio_dir, f"predicted_audio_{item_idx}.wav") sf.write(audio_path, predicted_audio_np, model.sample_rate) codes_path = os.path.join(pred_audio_dir, f"predicted_codes_{item_idx}.pt") - torch.save(predicted_codes[idx][:predicted_codes_lens[idx]], codes_path) + predicted_codes_current = predicted_codes[idx, :, :predicted_codes_lens[idx]] # C, T' + torch.save(predicted_codes_current, codes_path) codec_file_paths.append(codes_path) context_audio_path = manifest_records[item_idx].get('context_audio_filepath', None) target_audio_path = manifest_records[item_idx].get('audio_filepath', None) From ff72456c02fdef02d219d269905dcfc486fc9042 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Sat, 28 Jun 2025 01:50:34 -0700 Subject: [PATCH 057/113] [magpietts][eval][bugfix] fixed infer and eval scripts and supported loading wandb hparam config. (#13907) * [eval] bugfix infer and eval scripts and supports loading wandb config. * minor updates. * bugfix: checkpint_name can not be parsed when checkpoint_file is None * refactor: moved many variables irrelevant to num repeats outside of the for..loop of num_repeats; * fix default cer and ssim target. otherwise, float(None) errors out. * bugfix: fixed typos. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- scripts/magpietts/evaluate_generated_audio.py | 10 +- scripts/magpietts/infer_and_evaluate.py | 102 +++++++++++------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 237ad33e1c0f..10e872edb8d5 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -110,12 +110,10 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo device = "cuda" if language == "en": - if asr_model_name == "stt_en_conformer_transducer_large": - asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_large") - elif asr_model_name == "nvidia/parakeet-ctc-0.6b": - asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") - - # asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="nvidia/parakeet-tdt-1.1b") + if asr_model_name in ["nvidia/parakeet-tdt-1.1b", "nvidia/parakeet-ctc-0.6b", "stt_en_conformer_transducer_large"]: + asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=asr_model_name) + else: + raise ValueError(f"ASR model {asr_model_name} not supported") asr_model = asr_model.to(device) asr_model.eval() else: diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index a9a7fa43496b..8b10ce34a7ca 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -17,6 +17,7 @@ import json import os import shutil +import time import 
scripts.magpietts.evalset_config as evalset_config import scripts.magpietts.evaluate_generated_audio as evaluate_generated_audio @@ -134,14 +135,18 @@ def run_inference( maskgit_n_steps=3, legacy_codebooks=False, clean_up_disk=False, + hparams_file_from_wandb=False, log_exp_name=False ): # Load model - if hparams_file is not None: + if hparams_file is not None and checkpoint_file is not None: model_cfg = OmegaConf.load(hparams_file) if "cfg" in model_cfg: model_cfg = model_cfg.cfg + if hparams_file_from_wandb: + model_cfg = model_cfg.value + with open_dict(model_cfg): model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks) @@ -162,7 +167,7 @@ def run_inference( model.use_kv_cache_for_inference = True checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] else: - raise ValueError("Need a checkpoint") + raise ValueError("Need either a checkpoint and hparams file, or a nemo file.") if cfg_sample_rate is not None and cfg_sample_rate != model.sample_rate: raise ValueError("Sample rate in config and model do not match") @@ -204,24 +209,36 @@ def run_inference( print(f"Evaluating dataset {dataset}") metrics_n_repeated = [] manifest_records = read_manifest(dataset_meta_info[dataset]['manifest_path']) + language = dataset_meta_info[dataset].get('whisper_language', 'en') + dataset_meta_for_dl = copy.deepcopy(dataset_meta_info[dataset]) + for key in ["whisper_language", "load_cached_codes_if_available"]: + if key in dataset_meta_for_dl: + del dataset_meta_for_dl[key] + + dataset_meta = {dataset: dataset_meta_for_dl} + + eval_dir = os.path.join(out_dir, f"{checkpoint_name}_{dataset}") + audio_dir = os.path.join(eval_dir, "audio") + all_experiment_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") + os.makedirs(eval_dir, exist_ok=True) + + if not os.path.exists(all_experiment_csv): + with open(all_experiment_csv, "w") as f: + f.write( + "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,frechet_codec_distance\n" + ) + + context_duration_min = model.cfg.get('context_duration_min', 5.0) + context_duration_max = model.cfg.get('context_duration_max', 5.0) + if context_duration_min < 5.0 and context_duration_max > 5.0: + context_duration_min = 5.0 + context_duration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. 
+ for repeat_idx in range(num_repeats): - eval_dir = os.path.join(out_dir, "{}_{}".format(checkpoint_name, dataset)) - audio_dir = os.path.join(eval_dir, "audio") pred_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(pred_audio_dir, exist_ok=True) delete_old_generated_files(pred_audio_dir) - language = dataset_meta_info[dataset].get('whisper_language', 'en') - dataset_meta_for_dl = copy.deepcopy(dataset_meta_info[dataset]) - for key in ["whisper_language", "load_cached_codes_if_available"]: - if key in dataset_meta_for_dl: - del dataset_meta_for_dl[key] - - dataset_meta = {dataset: dataset_meta_for_dl} - context_durration_min = model.cfg.get('context_duration_min', 5.0) - context_durration_max = model.cfg.get('context_duration_max', 5.0) - if context_durration_min < 5.0 and context_durration_max > 5.0: - context_durration_min = 5.0 - context_durration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. + test_dataset = MagpieTTSDataset( dataset_meta=dataset_meta, sample_rate=model.sample_rate, @@ -242,10 +259,10 @@ def run_inference( load_16khz_audio=model.model_type == 'single_encoder_sv_tts', use_text_conditioning_tokenizer=model.use_text_conditioning_encoder, pad_context_text_to_max_duration=model.pad_context_text_to_max_duration, - context_duration_min=context_durration_min, - context_duration_max=context_durration_max, + context_duration_min=context_duration_min, + context_duration_max=context_duration_max, ) - assert len(test_dataset) == len(manifest_records), "Dataset length and manifest length should be the same. Dataset length: {}, Manifest length: {}".format(len(test_dataset), len(manifest_records)) + assert len(test_dataset) == len(manifest_records), f"Dataset length and manifest length should be the same. 
Dataset length: {len(test_dataset)}, Manifest length: {len(manifest_records)}" test_dataset.text_tokenizer = model.tokenizer test_dataset.text_conditioning_tokenizer = model.text_conditioning_tokenizer @@ -255,24 +272,23 @@ def run_inference( batch_size=batch_size, collate_fn=test_dataset.collate_fn, num_workers=2, - shuffle=False + shuffle=False, ) item_idx = 0 all_rtf_metrics = [] codec_file_paths = [] for bidx, batch in enumerate(test_data_loader): - print("Processing batch {} out of {} of dataset {}".format(bidx, len(test_data_loader), dataset)) - batch_cuda ={} + print(f"Processing batch {bidx} out of {len(test_data_loader)} of dataset {dataset}") + batch_cuda = {} for key in batch: if isinstance(batch[key], torch.Tensor): batch_cuda[key] = batch[key].cuda() else: batch_cuda[key] = batch[key] - import time st = time.time() - predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics, cross_attention_maps, _ = model.infer_batch( + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics, cross_attention_maps, _ = model.infer_batch( batch_cuda, max_decoder_steps=440, temperature=temperature, @@ -287,7 +303,7 @@ def run_inference( apply_prior_to_layers=apply_prior_to_layers, start_prior_after_n_audio_steps=start_prior_after_n_audio_steps, use_local_transformer_for_inference=use_local_transformer, - maskgit_n_steps=maskgit_n_steps + maskgit_n_steps=maskgit_n_steps, ) all_rtf_metrics.append(rtf_metrics) @@ -341,13 +357,10 @@ def run_inference( with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: json.dump(mean_rtf_metrics, f, indent=4) - all_experiment_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") - if not os.path.exists(all_experiment_csv): - with open(all_experiment_csv, "w") as f: - f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,frechet_codec_distance\n") with open(all_experiment_csv, "a") as f: f.write(f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']},{metrics['frechet_codec_distance']}\n") print(f"Wrote metrics for {checkpoint_name} and {dataset} to {all_experiment_csv}") + # Clean up temporary codec files for codes_file in codec_file_paths: os.remove(codes_file) @@ -385,6 +398,7 @@ def run_inference( def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") + parser.add_argument('--hparams_file_from_wandb', action='store_true') parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt") parser.add_argument('--nemo_file', 
type=str, default=None) parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo") @@ -414,8 +428,8 @@ def main(): parser.add_argument('--confidence_level', type=float, default=0.95) parser.add_argument('--legacy_codebooks', action='store_true') parser.add_argument('--clean_up_disk', action='store_true') - parser.add_argument('--cer_target', type=float, default=None) - parser.add_argument('--ssim_target', type=float, default=None) + parser.add_argument('--cer_target', type=float, default=1.0) + parser.add_argument('--ssim_target', type=float, default=0.) parser.add_argument('--log_exp_name', action='store_true', help="Include the experiment name (derived from the checkpoint path) in the output folder name.") args = parser.parse_args() @@ -426,6 +440,7 @@ def main(): if args.apply_prior_to_layers is not None: apply_prior_to_layers = [int(l.strip()) for l in args.apply_prior_to_layers.split(",")] + # Mode 1: Run inference from provided hparams and checkpoint files if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null") and (args.checkpoint_files != "null"): hparam_files = args.hparams_files.split(",") checkpoint_files = args.checkpoint_files.split(",") @@ -459,16 +474,17 @@ def main(): maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks, clean_up_disk=args.clean_up_disk, + hparams_file_from_wandb=args.hparams_file_from_wandb, log_exp_name=args.log_exp_name ) return - elif (args.nemo_file is not None): - nemo_file = args.nemo_file - print("Running inference for nemo file: ", nemo_file) + # Mode 2: Run inference from a .nemo file + elif args.nemo_file: + print(f"Running inference for nemo file: {args.nemo_file}") cer, ssim = run_inference( hparams_file=None, checkpoint_file=None, - nemo_file=nemo_file, + nemo_file=args.nemo_file, datasets=args.datasets.split(","), out_dir=args.out_dir, temperature=args.temperature, @@ -491,13 +507,15 @@ def main(): maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks, clean_up_disk=args.clean_up_disk, + hparams_file_from_wandb=args.hparams_file_from_wandb, log_exp_name=args.log_exp_name ) - else: + # Mode 3: Discover and run experiments from a base directory + # Mount DRACO_EXP_DIR to BASE_EXP_DIR as follows: + # sshfs -o allow_other pneekhara@draco-oci-dc-02.draco-oci-iad.nvidia.com:/lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/experiments/NewT5AllFixedFresh /datap/misc/dracomount/ + elif args.base_exp_dir: BASE_EXP_DIR = args.base_exp_dir DRACO_EXP_DIR = args.draco_exp_dir - # Mount DRACO_EXP_DIR to BASE_EXP_DIR as follows: - # sshfs -o allow_other pneekhara@draco-oci-dc-02.draco-oci-iad.nvidia.com:/lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/experiments/NewT5AllFixedFresh /datap/misc/dracomount/ if args.exp_names is None: exp_names = os.listdir(BASE_EXP_DIR) else: @@ -556,8 +574,16 @@ def main(): maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks, clean_up_disk=args.clean_up_disk, + hparams_file_from_wandb=args.hparams_file_from_wandb, log_exp_name=args.log_exp_name ) + else: + parser.error( + "You must provide a model to run. Please specify either:\n" + "1. --hparams_files and --checkpoint_files\n" + "2. --nemo_file\n" + "3. 
--base_exp_dir to discover experiments" + ) if cer > float(args.cer_target): raise ValueError() if ssim < float(args.ssim_target): From dbc6e7830c2ade529eecc904d0444e87637d027b Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 1 Jul 2025 10:01:54 -0700 Subject: [PATCH 058/113] Inference: pad short signals before embedding them (#14055) The speaker embedding model crashes on very short signals. So we zero-pad the end of the signal if it's less than 0.5 seconds long before running it through the speaker embedding model. Signed-off-by: Fejgin, Roy --- scripts/magpietts/evaluate_generated_audio.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 10e872edb8d5..296de8baf060 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -17,6 +17,7 @@ import pprint import string +import numpy as np import torch import nemo.collections.asr as nemo_asr @@ -85,9 +86,27 @@ def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, langua result = transcription[0] return result +def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_seconds: float) -> np.ndarray: + """ + Pad audio to make it at least `min_seconds` long by adding silence at the end if needed. + """ + if audio_np.ndim != 1: + raise ValueError("Audio array must be 1D") + + n_samples = len(audio_np) + min_samples = round(min_seconds * sampling_rate) + + if n_samples < min_samples: + print(f"Padding audio from {n_samples/sampling_rate} seconds to {min_samples/sampling_rate} seconds") + padding_needed = min_samples - n_samples + audio_np = np.pad(audio_np, (0, padding_needed), mode='constant', constant_values=0) + return audio_np + + def extract_embedding(model, extractor, audio_path, device, sv_model_type): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) - + # pad to 0.5 seconds as the extractor may not be able to handle very short signals + speech_array = pad_audio_to_min_length(speech_array, int(sampling_rate), min_seconds=0.5) if sv_model_type == "wavlm": inputs = extractor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device) with torch.no_grad(): From e2797eaff26d682f181774c7a3e6f41e26d8a69c Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Tue, 1 Jul 2025 12:56:27 -0700 Subject: [PATCH 059/113] [magpietts] decoder CE model type (#13727) * decoder CE model type Signed-off-by: Paarth Neekhara * infer and evaluate backward compatibility Signed-off-by: Paarth Neekhara * cond adapter for diff x attn size Signed-off-by: Paarth Neekhara * update alignment encoder Signed-off-by: Paarth Neekhara * clean up branch Signed-off-by: Paarth Neekhara * comment Signed-off-by: Paarth Neekhara * fix comment Signed-off-by: Paarth Neekhara * comments for different architectures Signed-off-by: Paarth Neekhara * reverse speaker projection change Signed-off-by: Paarth Neekhara * model type change Signed-off-by: Paarth Neekhara --------- Signed-off-by: Paarth Neekhara --- nemo/collections/tts/models/magpietts.py | 30 ++++++++++++++++++++---- scripts/magpietts/infer_and_evaluate.py | 5 ++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index fbfdfb2a5013..bb7e241c66cb 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -128,7 +128,7 @@ 
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.model_type = cfg.get('model_type', None) - self.pad_context_text_to_max_duration = self.model_type == 'decoder_context_tts' + self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce'] self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) super().__init__(cfg=cfg, trainer=trainer) @@ -173,6 +173,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) + self.final_proj = nn.Linear(cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook) self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower()) @@ -208,6 +209,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): ) if self.model_type == 'single_encoder_sv_tts': + # Context audio goes through Titanet to get speaker embedding + # Speaker embedding is added to the transcript encoder output self._speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( model_name='titanet_large' ) @@ -217,6 +220,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): idx for idx in range(self.decoder.n_layers) ] # All layers are used for text elif self.model_type == 'multi_encoder_context_tts': + # Transcript and context audio/text go to different encoders. + # Output of the encoders goes to the decoder through the cross-attention layers self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8]) self.context_decoder_layers = cfg.get( 'context_decoder_layers', [0, 1, 2, 9, 10, 11] @@ -229,10 +234,20 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.multi_encoder_mapping = multi_encoder_mapping self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder)) elif self.model_type == 'decoder_context_tts': + # Context audio/text goes directly to the decoder (before the target audio codes) self.transcript_decoder_layers = [ idx for idx in range(self.decoder.n_layers) ] # All layers are used for text + elif self.model_type == 'decoder_ce': + # Similar to decoder_context_tts, but we use context encoder + # Decoder gets output from context encoder instead of raw context tokens embeddings + self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder)) + self.transcript_decoder_layers = [ + idx for idx in range(cfg.decoder.n_layers) + ] # All layers are used for text + elif self.model_type == 'decoder_pretrain_synthesizer': + # This is for pretraining the decoder only on audio data using next frame prediction loss assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" else: raise ValueError(f"Unsupported model type {self.model_type}") @@ -888,7 +903,7 @@ def prepare_context_tensors(self, batch): text_lens = None # self.model_type must be one of - # [single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts, decoder_pretrain_synthesizer] + # [single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts, decoder_ce, decoder_pretrain_synthesizer] if self.model_type != 'decoder_pretrain_synthesizer': text = batch['text'] text_lens = batch['text_lens'] @@ -907,7 +922,7 @@ def prepare_context_tensors(self, batch): cond_mask = text_mask multi_encoder_mapping = None attn_prior = _attn_prior - elif self.model_type in 
['multi_encoder_context_tts', 'decoder_context_tts']: + elif self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']: if 'context_audio_codes' in batch: context_audio_codes = batch['context_audio_codes'] context_audio_codes_lens = batch['context_audio_codes_lens'] @@ -961,9 +976,14 @@ def prepare_context_tensors(self, batch): multi_encoder_mapping = self.multi_encoder_mapping attn_prior = [_attn_prior, None] - elif self.model_type == 'decoder_context_tts': + elif self.model_type in ['decoder_context_tts', 'decoder_ce']: dec_context_size = context_mask.size(1) - context_embeddings = context_input_embedded + if self.model_type == 'decoder_context_tts': + context_embeddings = context_input_embedded + elif self.model_type == 'decoder_ce': + context_embeddings = self.context_encoder( + context_input_embedded, context_mask, cond=None, cond_mask=None + )['output'] attn_prior = _attn_prior if attn_prior is not None: # B, audio_timesteps, text_timesteps diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 8b10ce34a7ca..66ffb96919ec 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -63,6 +63,11 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False): if hasattr(model_cfg, 'decoder') and hasattr(model_cfg.decoder, 'prior_eps'): # Added to prevent crash after removing arg from transformer_2501.py in https://github.com/blisc/NeMo/pull/56 del model_cfg.decoder.prior_eps + if hasattr(model_cfg, 'use_local_transformer') and model_cfg.use_local_transformer: + # For older checkpoints trained with a different parameter name + model_cfg.local_transformer_type = "autoregressive" + del model_cfg.use_local_transformer + if legacy_codebooks: # Added to address backward compatibility arising from # https://github.com/blisc/NeMo/pull/64 From 5ee5698492eec4d133e68260a09ec1eb1c3fbb05 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 1 Jul 2025 13:05:24 -0700 Subject: [PATCH 060/113] EOS detection: check all codebooks (#14038) We previously only examined at the first codebook to check for EOS. This carried the risk of EOS tokens being returned to the use if they occurred only in codebooks other than codebook zero and would therefore not be detected. 
Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index bb7e241c66cb..1228d34581db 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -1738,9 +1738,9 @@ def infer_batch( for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: - pred_token = all_codes_next_argmax[item_idx][0].item() - pred_token_multinomial = audio_codes_next[item_idx][0].item() - if (pred_token == self.audio_eos_id) or (pred_token_multinomial == self.audio_eos_id): + eos_in_pred_tokens_argmax = (all_codes_next_argmax[item_idx] == self.audio_eos_id).any().item() + eos_in_pred_tokens_multinomial = (audio_codes_next[item_idx] == self.audio_eos_id).any().item() + if eos_in_pred_tokens_argmax or eos_in_pred_tokens_multinomial: print("End detected for item {} at timestep {}".format(item_idx, idx)) end_indices[item_idx] = idx From 4ff2e0e402ba9c90b0b97ddc4992076fbecddfb1 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Thu, 3 Jul 2025 17:02:45 -0700 Subject: [PATCH 061/113] [eval] added the support of evaluating multiple .nemo checkpoints just the same as .ckpt checkpoints. (#14082) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- scripts/magpietts/infer_and_evaluate.py | 67 +++++++++++++------------ 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 66ffb96919ec..89c8c9e6684a 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -405,8 +405,8 @@ def main(): parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") parser.add_argument('--hparams_file_from_wandb', action='store_true') parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt") - parser.add_argument('--nemo_file', type=str, default=None) - parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo") + parser.add_argument('--nemo_files', type=str, default=None) + parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo", help="Path to codec model (only used when --compute_fcd is specified)") parser.add_argument('--datasets', type=str, default="libri_unseen_test_12.5") parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmountedresson/") parser.add_argument('--draco_exp_dir', type=str, default="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/experiments/NewT5TTS_FixedPosEmb/AllKernselSize3/EdressonCodecExperiments/") @@ -484,37 +484,38 @@ def main(): ) return # Mode 2: Run inference from a .nemo file - elif args.nemo_file: - print(f"Running inference for nemo file: {args.nemo_file}") - cer, ssim = run_inference( - hparams_file=None, - checkpoint_file=None, - nemo_file=args.nemo_file, - 
datasets=args.datasets.split(","), - out_dir=args.out_dir, - temperature=args.temperature, - topk=args.topk, - codecmodel_path=args.codecmodel_path, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - batch_size=args.batch_size, - sv_model=args.sv_model, - asr_model_name=args.asr_model_name, - num_repeats=args.num_repeats, - apply_attention_prior=args.apply_attention_prior, - attention_prior_epsilon=args.attention_prior_epsilon, - attention_prior_lookahead_window=args.attention_prior_lookahead_window, - estimate_alignment_from_layers=estimate_alignment_from_layers, - apply_prior_to_layers=apply_prior_to_layers, - start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, - confidence_level=args.confidence_level, - use_local_transformer=args.use_local_transformer, - maskgit_n_steps=args.maskgit_n_steps, - legacy_codebooks=args.legacy_codebooks, - clean_up_disk=args.clean_up_disk, - hparams_file_from_wandb=args.hparams_file_from_wandb, - log_exp_name=args.log_exp_name - ) + elif args.nemo_files: + print(f"Running inference for nemo file: {args.nemo_files}") + for nemo_file in args.nemo_files.split(","): + cer, ssim = run_inference( + hparams_file=None, + checkpoint_file=None, + nemo_file=nemo_file, + datasets=args.datasets.split(","), + out_dir=args.out_dir, + temperature=args.temperature, + topk=args.topk, + codecmodel_path=args.codecmodel_path, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + batch_size=args.batch_size, + sv_model=args.sv_model, + asr_model_name=args.asr_model_name, + num_repeats=args.num_repeats, + apply_attention_prior=args.apply_attention_prior, + attention_prior_epsilon=args.attention_prior_epsilon, + attention_prior_lookahead_window=args.attention_prior_lookahead_window, + estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, + confidence_level=args.confidence_level, + use_local_transformer=args.use_local_transformer, + maskgit_n_steps=args.maskgit_n_steps, + legacy_codebooks=args.legacy_codebooks, + clean_up_disk=args.clean_up_disk, + hparams_file_from_wandb=args.hparams_file_from_wandb, + log_exp_name=args.log_exp_name + ) # Mode 3: Discover and run experiments from a base directory # Mount DRACO_EXP_DIR to BASE_EXP_DIR as follows: # sshfs -o allow_other pneekhara@draco-oci-dc-02.draco-oci-iad.nvidia.com:/lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/experiments/NewT5AllFixedFresh /datap/misc/dracomount/ From 41379eef033173bc0cef9dd809f12a1ac6fffa5d Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Thu, 3 Jul 2025 17:02:59 -0700 Subject: [PATCH 062/113] make feature_dir default to None to avoid adding feature_dir in (#14080) the eval config. 
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- nemo/collections/tts/data/text_to_speech_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 47083c2b538d..ad2f9297e470 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -44,7 +44,7 @@ class DatasetMeta: manifest_path: Path audio_dir: Path - feature_dir: Path + feature_dir: Path = None sample_weight: float = 1.0 tokenizer_names: List[str] = None @@ -195,8 +195,8 @@ def _preprocess_manifest( sample = DatasetSample( dataset_name=dataset_name, manifest_entry=entry, - audio_dir=Path(dataset.audio_dir), - feature_dir=Path(dataset.feature_dir) if dataset.feature_dir is not None else None, + audio_dir=dataset.audio_dir, + feature_dir=dataset.feature_dir, text=text, speaker=speaker, speaker_index=speaker_index, From 113e0e2783a1aadecbd5b19a517bf6600ae8a13a Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Mon, 7 Jul 2025 09:41:26 -0700 Subject: [PATCH 063/113] [eval] make compute_fcd optional for now. (#14081) * [eval] make compute_fcd optional for now. Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * make fcd computation by default, but instead, add --disable-fcd as an option to disable fcd computation Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- scripts/magpietts/infer_and_evaluate.py | 50 ++++++++++++++++++------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 89c8c9e6684a..4604083edd7c 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -141,7 +141,8 @@ def run_inference( legacy_codebooks=False, clean_up_disk=False, hparams_file_from_wandb=False, - log_exp_name=False + log_exp_name=False, + compute_fcd=False, ): # Load model if hparams_file is not None and checkpoint_file is not None: @@ -229,9 +230,11 @@ def run_inference( if not os.path.exists(all_experiment_csv): with open(all_experiment_csv, "w") as f: - f.write( - "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,frechet_codec_distance\n" - ) + header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative" + if compute_fcd: + header += ",frechet_codec_distance" + header += "\n" + f.write(header) context_duration_min = model.cfg.get('context_duration_min', 5.0) context_duration_max = model.cfg.get('context_duration_max', 5.0) @@ -349,7 +352,7 @@ def run_inference( language=language, 
sv_model_type=sv_model, asr_model_name=asr_model_name, - codecmodel_path=codecmodel_path + codecmodel_path=codecmodel_path if compute_fcd else None ) metrics_n_repeated.append(metrics) with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: @@ -363,7 +366,11 @@ def run_inference( json.dump(mean_rtf_metrics, f, indent=4) with open(all_experiment_csv, "a") as f: - f.write(f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']},{metrics['frechet_codec_distance']}\n") + data = f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']}" + if compute_fcd: + data += f",{metrics['frechet_codec_distance']}" + data += "\n" + f.write(data) print(f"Wrote metrics for {checkpoint_name} and {dataset} to {all_experiment_csv}") # Clean up temporary codec files @@ -373,15 +380,25 @@ def run_inference( metric_keys = ['cer_filewise_avg', 'wer_filewise_avg', 'cer_cumulative', 'wer_cumulative', 'ssim_pred_gt_avg', 'ssim_pred_context_avg', 'ssim_gt_context_avg', 'ssim_pred_gt_avg_alternate', 'ssim_pred_context_avg_alternate', 'ssim_gt_context_avg_alternate', - 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative', 'frechet_codec_distance' + 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative' ] + if compute_fcd: + metric_keys.append('frechet_codec_distance') metrics_mean_ci = compute_mean_and_confidence_interval(metrics_n_repeated, metric_keys, confidence=confidence_level) all_experiment_csv_with_ci = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") if not os.path.exists(all_experiment_csv_with_ci): with open(all_experiment_csv_with_ci, "w") as f: - f.write("checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,frechet_codec_distance\n") + header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative" + if compute_fcd: + header += ",frechet_codec_distance" + header += "\n" + f.write(header) with open(all_experiment_csv_with_ci, "a") as f: - 
f.write(f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']},{metrics_mean_ci['frechet_codec_distance']}\n") + data = f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']}" + if compute_fcd: + data += f",{metrics_mean_ci['frechet_codec_distance']}" + data += "\n" + f.write(data) print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") @@ -406,7 +423,7 @@ def main(): parser.add_argument('--hparams_file_from_wandb', action='store_true') parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt") parser.add_argument('--nemo_files', type=str, default=None) - parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo", help="Path to codec model (only used when --compute_fcd is specified)") + parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo", help="Path to codec model (used for FCD computation unless --disable_fcd is specified)") parser.add_argument('--datasets', type=str, default="libri_unseen_test_12.5") parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmountedresson/") parser.add_argument('--draco_exp_dir', type=str, default="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/experiments/NewT5TTS_FixedPosEmb/AllKernselSize3/EdressonCodecExperiments/") @@ -436,8 +453,12 @@ def main(): parser.add_argument('--cer_target', type=float, default=1.0) parser.add_argument('--ssim_target', type=float, default=0.) 
parser.add_argument('--log_exp_name', action='store_true', help="Include the experiment name (derived from the checkpoint path) in the output folder name.") + parser.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation") args = parser.parse_args() + # FCD computation is enabled by default, disabled only when --disable_fcd is specified + compute_fcd = not args.disable_fcd + estimate_alignment_from_layers = None if args.estimate_alignment_from_layers is not None: estimate_alignment_from_layers = [int(l.strip()) for l in args.estimate_alignment_from_layers.split(",")] @@ -480,7 +501,8 @@ def main(): legacy_codebooks=args.legacy_codebooks, clean_up_disk=args.clean_up_disk, hparams_file_from_wandb=args.hparams_file_from_wandb, - log_exp_name=args.log_exp_name + log_exp_name=args.log_exp_name, + compute_fcd=compute_fcd ) return # Mode 2: Run inference from a .nemo file @@ -514,7 +536,8 @@ def main(): legacy_codebooks=args.legacy_codebooks, clean_up_disk=args.clean_up_disk, hparams_file_from_wandb=args.hparams_file_from_wandb, - log_exp_name=args.log_exp_name + log_exp_name=args.log_exp_name, + compute_fcd=compute_fcd ) # Mode 3: Discover and run experiments from a base directory # Mount DRACO_EXP_DIR to BASE_EXP_DIR as follows: @@ -581,7 +604,8 @@ def main(): legacy_codebooks=args.legacy_codebooks, clean_up_disk=args.clean_up_disk, hparams_file_from_wandb=args.hparams_file_from_wandb, - log_exp_name=args.log_exp_name + log_exp_name=args.log_exp_name, + compute_fcd=compute_fcd ) else: parser.error( From db68d32d994cc495efeec922c2d64e19577df266 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 8 Jul 2025 07:04:34 -0700 Subject: [PATCH 064/113] FCD Metric: if provided with codes of unexpected shape, log warning but do not crash (#14035) Signed-off-by: Fejgin, Roy --- nemo/collections/tts/modules/fcd_metric.py | 4 ++++ tests/collections/tts/modules/test_fcd_metric.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/nemo/collections/tts/modules/fcd_metric.py b/nemo/collections/tts/modules/fcd_metric.py index 9c2c4d31b540..c0ca4c406500 100644 --- a/nemo/collections/tts/modules/fcd_metric.py +++ b/nemo/collections/tts/modules/fcd_metric.py @@ -172,6 +172,10 @@ def update(self, codes: Tensor, codes_len: Tensor, is_real: bool): logging.warning(f"\nFCD metric received an empty batch of codes - skipping update\n") return + if codes.shape[1] != self.model.codec.num_codebooks: + logging.warning(f"\nFCD metric received a batch of codes of shape {codes.shape}, but the model has {self.model.codec.num_codebooks} codebooks - skipping update\n") + return + # Dequantize the codes to a continuous representation embeddings = self.model.codes_to_embedding( codes, codes_len diff --git a/tests/collections/tts/modules/test_fcd_metric.py b/tests/collections/tts/modules/test_fcd_metric.py index bba65cf7b821..dc87548ddd8f 100644 --- a/tests/collections/tts/modules/test_fcd_metric.py +++ b/tests/collections/tts/modules/test_fcd_metric.py @@ -110,3 +110,14 @@ def test_empty_codes_update(self, metric, device): codes_len = T * torch.ones(B, device=device) # if it crashes PyTest will report it metric.update(codes, codes_len, is_real=True) + + @pytest.mark.unit + def test_codebooks_mismatch_update(self, metric, device, codec): + """Test that the FCD metric doesn't crash when provided with incorrect number ofcodebooks.""" + B = 2 + C = codec.num_codebooks - 1 # intentionally missing one codebook + T = 10 + codes = torch.ones(B, C, T, device=device) + codes_len = T * 
torch.ones(B, device=device, dtype=torch.long) + # if it crashes PyTest will report it + metric.update(codes, codes_len, is_real=True) From efa9c1345db3be1ed62fbd56ea9ea7e634266eaf Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 8 Jul 2025 07:05:42 -0700 Subject: [PATCH 065/113] [eval] Suppress TitaNet messages during initialization (#14101) * Suppress massive titanet messages during init Signed-off-by: Fejgin, Roy * Comments Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- scripts/magpietts/evaluate_generated_audio.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 296de8baf060..e28ca5002c8f 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -16,6 +16,8 @@ import os import pprint import string +import logging +from contextlib import contextmanager import numpy as np import torch @@ -102,6 +104,23 @@ def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_second audio_np = np.pad(audio_np, (0, padding_needed), mode='constant', constant_values=0) return audio_np +@contextmanager +def nemo_log_level(level): + """ + A context manager that temporarily sets the logging level for the NeMo logger + and restores the original level when the context manager is exited. + + Args: + level (int): The logging level to set. + """ + logger = logging.getLogger("nemo_logger") + original_level = logger.level + logger.setLevel(level) + try: + yield + finally: + # restore the original level when the context manager is exited (even if an exception was raised) + logger.setLevel(original_level) def extract_embedding(model, extractor, audio_path, device, sv_model_type): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) @@ -149,8 +168,10 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') speaker_verification_model = speaker_verification_model.to(device) speaker_verification_model.eval() - - speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_small') + with nemo_log_level(logging.ERROR): + # The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily + print("Loading `titanet_small` model...") + speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_small') speaker_verification_model_alternate = speaker_verification_model_alternate.to(device) speaker_verification_model_alternate.eval() From 04078a92f7e442e64e4fdf7b897b188c6702930c Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 18 Jul 2025 13:17:16 -0700 Subject: [PATCH 066/113] [lhotse][sampler] added lhotse sampler that filters out records that has validation status not pass. (#14280) * added lhotse sampler that filters out records that have custom field validation_status not pass. add tests for filters. 
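A rough usage sketch of the new filter (the wrapper function here is illustrative, not part of this change): it keeps only cuts whose custom `validation_status` field equals the configured `keep` value, and passes through cuts that do not carry the field at all.

```
from lhotse import CutSet

from nemo.collections.common.data.lhotse.sampling import ValidationStatusFilter


def drop_failed_validation(cuts: CutSet, keep: str = "pass") -> CutSet:
    # Cuts whose validation_status differs from `keep` are dropped; cuts without the
    # custom field (and non-Cut examples) are kept unchanged.
    return cuts.filter(ValidationStatusFilter(keep=keep))
```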
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * run pre-compile to format the test file Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- .../common/data/lhotse/dataloader.py | 6 + .../common/data/lhotse/sampling.py | 14 ++ .../common/test_lhotse_tts_filters.py | 143 ++++++++++++++++++ 3 files changed, 163 insertions(+) create mode 100644 tests/collections/common/test_lhotse_tts_filters.py diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 20beff3c3817..d0c9c60bf17c 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -45,6 +45,7 @@ ) from nemo.collections.common.data.lhotse.sampling import ( BucketingFilter, + ValidationStatusFilter, CERFilter, ContextSpeakerSimilarityFilter, DurationFilter, @@ -136,6 +137,9 @@ class LhotseDataLoadingConfig: max_cer: float | None = float("inf") min_context_speaker_similarity: float | None = -1 + # 2.4 Filters on validation status. If the validation status is not "pass", the cut will be filtered out. + keep: str = "pass" + # 3. Supported existing NeMo options. shuffle: bool = False sample_rate: int = 16000 @@ -546,6 +550,8 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No TokenCountFilter(config.min_tokens, config.max_tokens, measure_total_length=config.measure_total_length) ) + # validation status filtering + cuts = cuts.filter(ValidationStatusFilter(config.keep)) # CER filtering, same as native NeMo dataloaders. cuts = cuts.filter(CERFilter(config.max_cer)) # Context speaker similarity filtering, same as native NeMo dataloaders. diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 48944fdb89b2..881bbfe3190d 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -262,6 +262,20 @@ def __call__(self, example) -> bool: else: return True # does not apply to text etc. +class ValidationStatusFilter: + """ + Callable, returns ``True`` if a cut's validation status is equal to keep and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + def __init__(self, keep: str = "pass") -> None: + self.keep = keep + + def __call__(self, example) -> bool: + if isinstance(example, Cut) and example.has_custom("validation_status") and example.validation_status != self.keep: + return False + else: + return True + class CERFilter: """ Callable, returns ``True`` if a cut's CER is less than max_cer and ``False`` otherwise. diff --git a/tests/collections/common/test_lhotse_tts_filters.py b/tests/collections/common/test_lhotse_tts_filters.py new file mode 100644 index 000000000000..877b0886ab97 --- /dev/null +++ b/tests/collections/common/test_lhotse_tts_filters.py @@ -0,0 +1,143 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from lhotse import SupervisionSegment +from lhotse.array import Array, TemporalArray +from lhotse.audio import AudioSource, Recording +from lhotse.cut import MonoCut + +from nemo.collections.common.data.lhotse.sampling import ( + CERFilter, + ContextSpeakerSimilarityFilter, + ValidationStatusFilter, +) + + +@pytest.fixture +def cut_example(): + cut = MonoCut( + id='cut-rec-Zdud2gXLTXY-238.16-6.88_repeat0', + start=238.16, + duration=6.88, + channel=0, + supervisions=[ + SupervisionSegment( + id='sup-rec-Zdud2gXLTXY', + recording_id='rec-Zdud2gXLTXY', + start=238.16, + duration=6.88, + channel=0, + text='and in like manner, as do other parts in which there appears to exist an adaptation to an end.', + language='en', + speaker='| Language:en Dataset:nvyt2505 Speaker:Zdud2gXLTXY_SPEAKER_02 |', + gender=None, + custom={ + 'cer': 0.03, + 'bandwidth': 10875, + 'stoi_squim': 0.921, + 'sisdr_squim': 15.17, + 'pesq_squim': 1.845, + 'dataset_id': '5a6446c5-6114-4380-b875-9de17fda2b8d', + 'dataset_version': '2024_11_07_131440', + 'dataset_name': 'yt_mixed', + 'context_speaker_similarity': 0.9172529578208923, + 'context_audio_offset': 7001.95659375, + 'context_audio_duration': 14.64, + 'context_audio_text': 'Uat gives an excellent illustration of the effects of a course of selection, which may be considered as unconscious, insofar that the breeders could never have expected, or even wished, to produce the result which ensued,', + 'context_recording_id': 'rec-Zdud2gXLTXY', + }, + alignment=None, + ) + ], + features=None, + recording=Recording( + id='rec-Zdud2gXLTXY', + sources=[AudioSource(type='file', channels=[0], source='/audio/Zdud2gXLTXY.wav')], + sampling_rate=22050, + num_samples=952064173, + duration=43177.51351473923, + channel_ids=[0], + transforms=None, + ), + custom={ + 'validation_status': 'pass', + 'target_audio': Recording( + id='cut-rec-Zdud2gXLTXY-238.16-6.88', + sources=[AudioSource(type='memory', channels=[0], source='')], + sampling_rate=22050, + num_samples=151704, + duration=6.88, + channel_ids=[0], + transforms=None, + ), + 'context_audio': Recording( + id='context_cut-rec-Zdud2gXLTXY-7001.96-14.64', + sources=[AudioSource(type='memory', channels=[0], source='')], + sampling_rate=22050, + num_samples=322812, + duration=14.64, + channel_ids=[0], + transforms=None, + ), + 'target_codes': TemporalArray( + array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 149]), + temporal_dim=-1, + frame_shift=0.046511627906976744, + start=0, + ), + 'context_codes': TemporalArray( + array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 316]), + temporal_dim=-1, + frame_shift=0.046511627906976744, + start=0, + ), + 'shard_origin': '/cuts/cuts.000001.jsonl.gz', + 'shar_epoch': 0, + 'tokenizer_names': ['english_phoneme'], + }, + ) + return cut + + +def test_cut_cer_filter(cut_example): + f = CERFilter(0.4) + assert f(cut_example) == True + + f = CERFilter(0.01) + assert f(cut_example) == False + + f = CERFilter(float("inf")) + assert f(cut_example) == True + + +def 
test_cut_context_speaker_similarity_filter(cut_example): + f = ContextSpeakerSimilarityFilter(0.6) + assert f(cut_example) == True + + f = ContextSpeakerSimilarityFilter(0.95) + assert f(cut_example) == False + + f = ContextSpeakerSimilarityFilter(-1) + assert f(cut_example) == True + + +def test_cut_validation_status_filter(cut_example): + f = ValidationStatusFilter("pass") + assert f(cut_example) == True + + f = ValidationStatusFilter("wrong_text") + assert f(cut_example) == False + + f = ValidationStatusFilter("any_other_status") + assert f(cut_example) == False From 85a41917983dc0dee03659905bf093d578a76584 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 24 Jul 2025 16:33:16 -0400 Subject: [PATCH 067/113] Infer updates: G2P_Prob=1, Add Violin Plots, ASR Default, Cleanups (#13889) * set g2p to 1 Signed-off-by: Jason * clean up branch Signed-off-by: Jason * add fix for nemo file; and include check for None for targets Signed-off-by: Jason * add violin plots to inference script; cleean up call to run_inference() Signed-off-by: Jason * undo unintended change Signed-off-by: Jason * Update scripts/magpietts/infer_and_evaluate.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jason * update asr default Signed-off-by: Jason --------- Signed-off-by: Jason Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- scripts/magpietts/evalset_config.py | 20 ++ scripts/magpietts/evaluate_generated_audio.py | 6 +- scripts/magpietts/infer_and_evaluate.py | 190 ++++++++++-------- 3 files changed, 127 insertions(+), 89 deletions(-) diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index a7b88d8eb5cd..01bba47c6b2b 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -324,4 +324,24 @@ 'audio_dir' : '/', 'feature_dir' : None, }, + 'j_riva_digits': { + 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-digits-RivaEnContext.ndjson', + 'audio_dir' : '/', + 'feature_dir' : None, + }, + 'j_riva_letters': { + 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-letters-RivaEnContext.ndjson', + 'audio_dir' : '/', + 'feature_dir' : None, + }, + 'j_riva_money': { + 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-digits-RivaEnContext.ndjson', + 'audio_dir' : '/', + 'feature_dir' : None, + }, + 'j_riva_short': { + 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-short-RivaEnContext.ndjson', + 'audio_dir' : '/', + 'feature_dir' : None, + }, } \ No newline at end of file diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index e28ca5002c8f..3553ac98fee8 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -97,7 +97,7 @@ def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_second n_samples = len(audio_np) min_samples = round(min_seconds * sampling_rate) - + if n_samples < min_samples: print(f"Padding audio from {n_samples/sampling_rate} seconds to {min_samples/sampling_rate} seconds") padding_needed = min_samples - n_samples @@ -107,7 +107,7 @@ def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_second @contextmanager def nemo_log_level(level): """ - A context manager that temporarily sets the logging level for the NeMo logger + A context manager that temporarily sets the logging level for the NeMo logger and restores the original level when the context manager is exited. 
Args: @@ -148,7 +148,7 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo device = "cuda" if language == "en": - if asr_model_name in ["nvidia/parakeet-tdt-1.1b", "nvidia/parakeet-ctc-0.6b", "stt_en_conformer_transducer_large"]: + if asr_model_name.startswith("nvidia/") or asr_model_name in ["stt_en_conformer_transducer_large"]: asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=asr_model_name) else: raise ValueError(f"ASR model {asr_model_name} not supported") diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 4604083edd7c..c4132c10093f 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -18,6 +18,9 @@ import os import shutil import time +from typing import List +from pathlib import Path +from functools import partial import scripts.magpietts.evalset_config as evalset_config import scripts.magpietts.evaluate_generated_audio as evaluate_generated_audio @@ -27,10 +30,13 @@ import torch from omegaconf.omegaconf import OmegaConf, open_dict from PIL import Image +import matplotlib.pyplot as plt +import pandas as pd from nemo.collections.asr.parts.utils.manifest_utils import read_manifest from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset from nemo.collections.tts.models import MagpieTTSModel +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): metrics = {} @@ -113,6 +119,50 @@ def delete_old_generated_files(output_dir): for f in glob.glob(f"{output_dir}/predicted_audio*.wav"): os.remove(f) +def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: str): + # Create dataframe from list of dicts + df = pd.DataFrame(metrics) + + # Plot the violin plots for all DataFrames side by side + num_columns = len(metric_keys) + width = num_columns * 5 + fig, axs = plt.subplots(1, num_columns, figsize=(width, 4)) + + for i, column in enumerate(metric_keys): + assert column in df + # Create empty lists to store the parts objects for each DataFrame + # Plot the violin plots for each DataFrame + axs[i].violinplot( + df[column], showmedians=True, positions=[i], widths=0.5 + ) + + axs[i].set_title(column) + axs[i].set_xticks([i]) + axs[i].set_xticklabels([column]) + axs[i].grid(True, linestyle="dotted") + + # Calculate and display the mean value for each DataFrame + mean = df[column].mean() + sem = df[column].sem() + axs[i].plot( + i, + mean, + "o", + color="red", + markersize=4, + label="Mean (95%CI)" + ) + + label_numeric = f"{mean:.2f}±{1.96 * sem:.2f}" + axs[i].text(i + 0.06, mean, label_numeric, ha="center", va="top") + + # Create a single legend for all subplots + handles, labels = axs[0].get_legend_handles_labels() + fig.legend(handles, labels, loc="upper left") + + plt.tight_layout() + plt.savefig(output_png, format="png", bbox_inches="tight") + def run_inference( hparams_file, @@ -143,6 +193,7 @@ def run_inference( hparams_file_from_wandb=False, log_exp_name=False, compute_fcd=False, + violin_plot_metrics=['cer', 'pred_context_ssim'] ): # Load model if hparams_file is not None and checkpoint_file is not None: @@ -177,7 +228,7 @@ def run_inference( if cfg_sample_rate is not None and cfg_sample_rate != model.sample_rate: raise ValueError("Sample rate in config and model do not match") - + print("Loaded weights.") model.cuda() model.eval() @@ -246,7 +297,7 @@ def 
run_inference( pred_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(pred_audio_dir, exist_ok=True) delete_old_generated_files(pred_audio_dir) - + test_dataset = MagpieTTSDataset( dataset_meta=dataset_meta, sample_rate=model.sample_rate, @@ -273,6 +324,14 @@ def run_inference( assert len(test_dataset) == len(manifest_records), f"Dataset length and manifest length should be the same. Dataset length: {len(test_dataset)}, Manifest length: {len(manifest_records)}" test_dataset.text_tokenizer = model.tokenizer + # Set phoneme prob = 1 for g2p + g2p = None + if isinstance(model.tokenizer, AggregatedTTSTokenizer): + g2p = model.tokenizer.tokenizers["english_phoneme"].g2p + elif isinstance(model.tokenizer, IPATokenizer): + g2p = model.tokenizer.g2p + if g2p is not None: + g2p.phoneme_probability = 1.0 test_dataset.text_conditioning_tokenizer = model.text_conditioning_tokenizer test_data_loader = torch.utils.data.DataLoader( @@ -373,6 +432,9 @@ def run_inference( f.write(data) print(f"Wrote metrics for {checkpoint_name} and {dataset} to {all_experiment_csv}") + output_png_file = Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png" + create_violin_plots(filewise_metrics, violin_plot_metrics, output_png_file) + # Clean up temporary codec files for codes_file in codec_file_paths: os.remove(codes_file) @@ -400,7 +462,7 @@ def run_inference( data += "\n" f.write(data) print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") - + measurements = [m['ssim_pred_context_avg'] for m in metrics_n_repeated] ssim_current = np.mean(measurements) @@ -445,15 +507,16 @@ def main(): parser.add_argument('--topk', type=int, default=80) parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--sv_model', type=str, default="titanet") # titanet, wavlm - parser.add_argument('--asr_model_name', type=str, default="stt_en_conformer_transducer_large") # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b + parser.add_argument('--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b") # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) parser.add_argument('--legacy_codebooks', action='store_true') parser.add_argument('--clean_up_disk', action='store_true') - parser.add_argument('--cer_target', type=float, default=1.0) - parser.add_argument('--ssim_target', type=float, default=0.) 
+ parser.add_argument('--cer_target', type=float, default=None) + parser.add_argument('--ssim_target', type=float, default=None) parser.add_argument('--log_exp_name', action='store_true', help="Include the experiment name (derived from the checkpoint path) in the output folder name.") parser.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation") + parser.add_argument('--violin_plot_metrics', type=str, nargs='*', default=['cer','pred_context_ssim'], help="Which metrics to add the violin plot.") args = parser.parse_args() # FCD computation is enabled by default, disabled only when --disable_fcd is specified @@ -466,6 +529,36 @@ def main(): if args.apply_prior_to_layers is not None: apply_prior_to_layers = [int(l.strip()) for l in args.apply_prior_to_layers.split(",")] + run_inference_w_args = partial( + run_inference, + datasets=args.datasets.split(","), + out_dir=args.out_dir, + temperature=args.temperature, + topk=args.topk, + codecmodel_path=args.codecmodel_path, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + batch_size=args.batch_size, + sv_model=args.sv_model, + asr_model_name=args.asr_model_name, + num_repeats=args.num_repeats, + apply_attention_prior=args.apply_attention_prior, + attention_prior_epsilon=args.attention_prior_epsilon, + attention_prior_lookahead_window=args.attention_prior_lookahead_window, + estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, + confidence_level=args.confidence_level, + use_local_transformer=args.use_local_transformer, + maskgit_n_steps=args.maskgit_n_steps, + legacy_codebooks=args.legacy_codebooks, + clean_up_disk=args.clean_up_disk, + hparams_file_from_wandb=args.hparams_file_from_wandb, + log_exp_name=args.log_exp_name, + compute_fcd=compute_fcd, + violin_plot_metrics=args.violin_plot_metrics + ) + # Mode 1: Run inference from provided hparams and checkpoint files if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null") and (args.checkpoint_files != "null"): hparam_files = args.hparams_files.split(",") @@ -474,70 +567,20 @@ def main(): print("Running inference for checkpoint files: ", checkpoint_files) assert len(hparam_files) == len(checkpoint_files), "Number of hparams files and checkpoint files should be the same." 
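# The cleanup in this hunk binds every mode-independent argument once with functools.partial
# (`run_inference_w_args` above), so each of the three invocation modes only passes the
# checkpoint- or nemo-file-specific arguments. A minimal sketch of the same pattern, with
# illustrative placeholder names:
from functools import partial

def _run(checkpoint_file, dataset, temperature, use_cfg):
    print(f"run {checkpoint_file} on {dataset} (temperature={temperature}, use_cfg={use_cfg})")

_run_shared = partial(_run, dataset="libri_unseen_test", temperature=0.6, use_cfg=True)
_run_shared(checkpoint_file="model_a.ckpt")  # only the varying argument is supplied per call
_run_shared(checkpoint_file="model_b.ckpt")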
for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): - cer, ssim = run_inference( + cer, ssim = run_inference_w_args( hparams_file=hparams_file, checkpoint_file=checkpoint_file, nemo_file=None, - datasets=args.datasets.split(","), - out_dir=args.out_dir, - temperature=args.temperature, - topk=args.topk, - codecmodel_path=args.codecmodel_path, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - batch_size=args.batch_size, - sv_model=args.sv_model, - asr_model_name=args.asr_model_name, - num_repeats=args.num_repeats, - apply_attention_prior=args.apply_attention_prior, - attention_prior_epsilon=args.attention_prior_epsilon, - attention_prior_lookahead_window=args.attention_prior_lookahead_window, - estimate_alignment_from_layers=estimate_alignment_from_layers, - apply_prior_to_layers=apply_prior_to_layers, - start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, - confidence_level=args.confidence_level, - use_local_transformer=args.use_local_transformer, - maskgit_n_steps=args.maskgit_n_steps, - legacy_codebooks=args.legacy_codebooks, - clean_up_disk=args.clean_up_disk, - hparams_file_from_wandb=args.hparams_file_from_wandb, - log_exp_name=args.log_exp_name, - compute_fcd=compute_fcd ) return # Mode 2: Run inference from a .nemo file elif args.nemo_files: print(f"Running inference for nemo file: {args.nemo_files}") for nemo_file in args.nemo_files.split(","): - cer, ssim = run_inference( + cer, ssim = run_inference_w_args( hparams_file=None, checkpoint_file=None, nemo_file=nemo_file, - datasets=args.datasets.split(","), - out_dir=args.out_dir, - temperature=args.temperature, - topk=args.topk, - codecmodel_path=args.codecmodel_path, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - batch_size=args.batch_size, - sv_model=args.sv_model, - asr_model_name=args.asr_model_name, - num_repeats=args.num_repeats, - apply_attention_prior=args.apply_attention_prior, - attention_prior_epsilon=args.attention_prior_epsilon, - attention_prior_lookahead_window=args.attention_prior_lookahead_window, - estimate_alignment_from_layers=estimate_alignment_from_layers, - apply_prior_to_layers=apply_prior_to_layers, - start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, - confidence_level=args.confidence_level, - use_local_transformer=args.use_local_transformer, - maskgit_n_steps=args.maskgit_n_steps, - legacy_codebooks=args.legacy_codebooks, - clean_up_disk=args.clean_up_disk, - hparams_file_from_wandb=args.hparams_file_from_wandb, - log_exp_name=args.log_exp_name, - compute_fcd=compute_fcd ) # Mode 3: Discover and run experiments from a base directory # Mount DRACO_EXP_DIR to BASE_EXP_DIR as follows: @@ -577,35 +620,10 @@ def main(): print("Copied hparams file.") print("Hparams file path: ", hparams_copy_path) print("Checkpoint file path: ", checkpoint_copy_path) - run_inference( + run_inference_w_args( hparams_copy_path, checkpoint_copy_path, nemo_file=None, - datasets=args.datasets.split(","), - out_dir=args.out_dir, - temperature=args.temperature, - topk=args.topk, - codecmodel_path=args.codecmodel_path, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - batch_size=args.batch_size, - sv_model=args.sv_model, - asr_model_name=args.asr_model_name, - num_repeats=args.num_repeats, - apply_attention_prior=args.apply_attention_prior, - attention_prior_epsilon=args.attention_prior_epsilon, - attention_prior_lookahead_window=args.attention_prior_lookahead_window, - estimate_alignment_from_layers=estimate_alignment_from_layers, - 
apply_prior_to_layers=apply_prior_to_layers, - start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, - confidence_level=args.confidence_level, - use_local_transformer=args.use_local_transformer, - maskgit_n_steps=args.maskgit_n_steps, - legacy_codebooks=args.legacy_codebooks, - clean_up_disk=args.clean_up_disk, - hparams_file_from_wandb=args.hparams_file_from_wandb, - log_exp_name=args.log_exp_name, - compute_fcd=compute_fcd ) else: parser.error( @@ -614,9 +632,9 @@ def main(): "2. --nemo_file\n" "3. --base_exp_dir to discover experiments" ) - if cer > float(args.cer_target): + if args.cer_target is not None and cer > float(args.cer_target): raise ValueError() - if ssim < float(args.ssim_target): + if args.ssim_target is not None and ssim < float(args.ssim_target): raise ValueError() From f5510e9bf5ec6819dab4d5acb7897764a19da736 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 7 Aug 2025 16:49:04 -0400 Subject: [PATCH 068/113] Quick PR to disable broken tests on magpie dev branch for now (#14429) * disable broken tests on dev branch Signed-off-by: Jason * remove duplicated test Signed-off-by: Jason --------- Signed-off-by: Jason --- .github/workflows/install-test.yml | 6 +-- .../common/test_lhotse_dataloading.py | 37 ++----------------- 2 files changed, 7 insertions(+), 36 deletions(-) diff --git a/.github/workflows/install-test.yml b/.github/workflows/install-test.yml index a580fcaf8f26..bd5919701a99 100644 --- a/.github/workflows/install-test.yml +++ b/.github/workflows/install-test.yml @@ -2,8 +2,8 @@ name: CI-Install-Check on: pull_request: - paths: - - "**" + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -40,7 +40,7 @@ jobs: export NEMO_TAG export NEMO_REPO export INSTALL_DIR=$(pwd) - + bash reinstall.sh --library all --mode install fi diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index 8e941e4aefad..cfbb49ba7b41 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -344,7 +344,7 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path): assert batches[4]["audio"].shape == (4, 8000) # exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts - +@pytest.mark.pleasefixme def test_dataloader_from_lhotse_cuts_pad_min_duration(cutset_path: Path): config = OmegaConf.create( { @@ -1938,7 +1938,7 @@ def test_multimodal_text_audio_dataloading_randomized_round_robin_strategy( assert torch.is_tensor(ex.answer_ids) assert torch.is_tensor(ex.mask) - +@pytest.mark.pleasefixme def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): config = OmegaConf.create( { @@ -1967,36 +1967,7 @@ def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: assert isinstance(cut, MixedCut) assert -5.0 < cut.tracks[1].snr < 5.0 - -def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): - config = OmegaConf.create( - { - "cuts_path": str(cutset_path), - "noise_path": str(nemo_manifest_path), - "noise_mix_prob": 1.0, - "noise_snr": [-5.0, 5.0], - "batch_size": 2, - "seed": 0, - "shard_seed": 0, - } - ) - dl = get_lhotse_dataloader_from_config( - config=config, - global_rank=0, - world_size=1, - dataset=Identity(), - ) - batch = next(iter(dl)) - assert isinstance(batch, CutSet) - assert len(batch) == 2 - cut = batch[0] - assert isinstance(cut, MixedCut) - assert -5.0 < 
cut.tracks[1].snr < 5.0 - cut = batch[1] - assert isinstance(cut, MixedCut) - assert -5.0 < cut.tracks[1].snr < 5.0 - - +@pytest.mark.pleasefixme def test_dataloader_with_noise_lhotse_jsonl(cutset_path: Path): config = OmegaConf.create( { @@ -2025,7 +1996,7 @@ def test_dataloader_with_noise_lhotse_jsonl(cutset_path: Path): assert isinstance(cut, MixedCut) assert -5.0 < cut.tracks[1].snr < 5.0 - +@pytest.mark.pleasefixme def test_dataloader_with_noise_nemo_tar(cutset_path: Path, nemo_tarred_manifest_path_multi: Path): noise_json, noise_tar = nemo_tarred_manifest_path_multi config = OmegaConf.create( From 3df09e86ad126693f572d38acaef345815bbc5c6 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Tue, 12 Aug 2025 07:43:45 -0700 Subject: [PATCH 069/113] Magpietts 2503 po july2025 (#14393) * change to eval mode while generating Signed-off-by: Shehzeen Hussain * add cer thresholds for GRPO Signed-off-by: Shehzeen Hussain * added LT to GRPO Signed-off-by: Shehzeen Hussain * remove debug statements Signed-off-by: Shehzeen Hussain * added comments, disabled KV caching for local transformer during GRPO Signed-off-by: Shehzeen Hussain --------- Signed-off-by: Shehzeen Hussain --- nemo/collections/tts/models/magpietts.py | 8 ++- .../magpietts_preference_optimization.py | 62 ++++++++++++++++--- scripts/magpietts/README_magpie_po.md | 1 + 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 1228d34581db..976acd95119f 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -675,9 +675,9 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, codes = codes[:actual_batch_size] return codes - def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0): + def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0, use_kv_cache=True): # dec_output: (B, E) - self.local_transformer.reset_cache(use_cache=True) + self.local_transformer.reset_cache(use_cache=use_kv_cache) dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) all_preds = [] @@ -1554,6 +1554,7 @@ def infer_batch( start_prior_after_n_audio_steps=10, compute_all_heads_attn_maps=False, use_local_transformer_for_inference=False, + use_LT_kv_cache=True, maskgit_n_steps=3 ): with torch.no_grad(): @@ -1714,7 +1715,8 @@ def infer_batch( unfinished_items=unfinished_items, finished_items=finished_items, use_cfg=use_cfg, - cfg_scale=cfg_scale + cfg_scale=cfg_scale, + use_kv_cache=use_LT_kv_cache, ) elif self.local_transformer_type == LocalTransformerType.MASKGIT: audio_codes_next = self.local_transformer_sample_maskgit( diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index d29cf5b1c48c..3ed9107c4352 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -432,7 +432,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): use_pesq = self.cfg.get('use_pesq', False) if use_pesq: - # import ipdb; ipdb.set_trace() assert HAVE_TORCHAUDIO, "torchaudio is required for PESQ reward" self.squim_objective_model = 
SQUIM_OBJECTIVE.get_model() @@ -441,6 +440,13 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): raise ValueError(f"Received loss_type of {self.loss_type}, but the model only accepts one of ['grpo', 'dr_grpo']") self.scale_rewards = self.cfg.get('scale_rewards', True) self.max_decoder_steps = self.cfg.get('max_decoder_steps', 430) + # If the best record in the group is above this threshold, we will not use that group for training + # Setting this to 1.0, because we clamp the ASR rewards to be in [0, 1] for OnlinePO + self.best_cer_threshold = self.cfg.get('best_cer_threshold', 1.0) + # If the worst record in the group exceeds this threshold, we will not use that group for training + # Setting this to 1.0, because we clamp the ASR rewards to be in [0, 1] for OnlinePO + self.worst_cer_threshold = self.cfg.get('worst_cer_threshold', 1.0) + def state_dict(self, destination=None, prefix='', keep_vars=False): state_dict = super().state_dict(destination, prefix, keep_vars) @@ -476,7 +482,7 @@ def repeat_items_in_batch(self, batch, num_repeats): repeated_batch[key] = repeated_value return repeated_batch - def generate_and_reward(self, batch, num_generations_per_item, mode='train'): + def generate_and_reward(self, batch, num_generations_per_item, mode='train', use_local_transformer_for_inference=False): batch_repeated = self.repeat_items_in_batch(batch, num_generations_per_item) temperature = self.cfg.get('inference_temperature', 0.7) topk = self.cfg.get('inference_topk', 80) @@ -488,14 +494,16 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): # Randomly set use_cfg based on the given probability use_cfg = random.random() < self.cfg.inference_cfg_prob cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - print("use_cfg", use_cfg) + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch_repeated, max_decoder_steps=self.max_decoder_steps, temperature=temperature, topk=topk, use_cfg=use_cfg, - cfg_scale=cfg_scale + cfg_scale=cfg_scale, + use_local_transformer_for_inference=use_local_transformer_for_inference, + use_LT_kv_cache=False, # We don't use KV caching for local transformer in GRPO due to issues. 
) predicted_audio_paths = [] audio_durations = [] @@ -583,11 +591,15 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6) # SSIM equal to this value will have reward of 0.5 all_groups_mean_reward = 0.0 all_groups_std_reward = 0.0 + group_validities = [] for group_idx in range(num_groups): group_start_idx = group_idx * num_generations_per_item group_end_idx = group_start_idx + num_generations_per_item group_rewards = [] mean_reward = 0 + is_group_valid = True + group_best_cer = 1.0 + group_worst_cer = 0.0 for idx in range(group_start_idx, group_end_idx): # Lower CER and higher speaker similarity is better, means high reward # Higher pesq is better, means high reward @@ -597,6 +609,9 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): item_cer = min( max(item_cer, 0.0), 1.0) item_ssim = max( min(item_ssim, best_ssim_achievable), 0.0) item_pesq = batch_metrics[idx]['pesq'] + group_best_cer = min(group_best_cer, item_cer) + group_worst_cer = max(group_worst_cer, item_cer) + if item_cer <= mean_cer_dataset: cer_reward = 0.5 + 0.5 * (mean_cer_dataset - item_cer) / mean_cer_dataset # 0.5 to 1 else: @@ -621,6 +636,17 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): batch_metrics[idx]['pesq_reward'] = pesq_reward mean_reward += batch_metrics[idx]['reward'] group_rewards.append(batch_metrics[idx]['reward']) + + if (group_best_cer > self.best_cer_threshold): + is_group_valid = False + print(f"Group {group_idx} has best CER {group_best_cer} which is above the threshold {self.best_cer_threshold}. Group is invalid.") + + if (group_worst_cer > self.worst_cer_threshold): + is_group_valid = False + print(f"Group {group_idx} has worst CER {group_worst_cer} which is above the threshold {self.worst_cer_threshold}. 
Group is invalid.") + + for _ in range(num_generations_per_item): + group_validities.append(is_group_valid) mean_reward /= num_generations_per_item std_reward = np.std(group_rewards) @@ -636,6 +662,8 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): advantages = [x['advantage'] for x in batch_metrics] advantages = torch.tensor(advantages, device=self.device) print("Mean reward: ", all_groups_mean_reward) + + group_validities = torch.tensor(group_validities, device=self.device) return { 'mean_reward': torch.tensor(all_groups_mean_reward, device=self.device), 'std_reward': torch.tensor(all_groups_std_reward, device=self.device), @@ -644,16 +672,32 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train'): 'predicted_codes': predicted_codes, 'predicted_codes_lens': predicted_codes_lens, 'advantages': advantages, + 'group_validities': group_validities, } + def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): use_kv_cache_during_online_po = self.cfg.get("use_kv_cache_during_online_po", False) if use_kv_cache_during_online_po: self.use_kv_cache_for_inference = True self.decoder.reset_cache(use_cache=True) + use_local_transformer_for_inference = False + logits_key = 'logits' + use_local_transformer_prob = self.cfg.get('use_local_transformer_prob', 0.0) + if use_local_transformer_prob > 0.0 and mode == 'train': + use_local_transformer_for_inference = random.random() < use_local_transformer_prob + logits_key = 'local_transformer_logits' + with torch.no_grad(): - generated_codes_and_metrics = self.generate_and_reward(batch, n_generations_per_item, mode) + self.eval() + generated_codes_and_metrics = self.generate_and_reward( + batch, + n_generations_per_item, + mode, + use_local_transformer_for_inference=use_local_transformer_for_inference + ) + self.train() if use_kv_cache_during_online_po: self.use_kv_cache_for_inference = False @@ -691,16 +735,18 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): reference_codebook_loss_mask = reference_model_output['loss_mask'][:,codebook_idx,:] if not self.reference_free else None si = codebook_idx * self.num_all_tokens_per_codebook ei = si + self.num_all_tokens_per_codebook - codebook_logits = policy_model_outputs['logits'][:, :, si:ei] # B, T, C + + codebook_logits = policy_model_outputs[logits_key][:, :, si:ei] # B, T, C codebook_labels = batch_repeated['audio_codes'][:,codebook_idx,1:] per_token_codebook_log_probs = self._get_per_token_logps(codebook_logits, codebook_labels, policy_codebook_loss_mask) per_token_loss = -(torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach()) * advantages.unsqueeze(1)) - + group_validities = generated_codes_and_metrics['group_validities'] # B * n_generations_per_item + per_token_loss = per_token_loss * group_validities.unsqueeze(1) # B, T if not self.reference_free: with torch.no_grad(): - ref_codebook_logits = reference_model_output['logits'][:, :, si:ei] + ref_codebook_logits = reference_model_output[logits_key][:, :, si:ei] per_token_ref_codebook_log_probs = self._get_per_token_logps(ref_codebook_logits, codebook_labels, reference_codebook_loss_mask) # https://github.com/huggingface/trl/blob/ffcb9f4aee725a2bd072d0387afe68a4b1c7967c/trl/trainer/grpo_trainer.py#L703 per_token_codebook_kl = torch.exp(per_token_ref_codebook_log_probs - per_token_codebook_log_probs) - (per_token_ref_codebook_log_probs - per_token_codebook_log_probs) - 1 diff --git a/scripts/magpietts/README_magpie_po.md 
b/scripts/magpietts/README_magpie_po.md index d2dab048e1ec..fee1b575c124 100644 --- a/scripts/magpietts/README_magpie_po.md +++ b/scripts/magpietts/README_magpie_po.md @@ -150,6 +150,7 @@ To train with GRPO, we use a similar training command as the base model training +model.num_generations_per_item=12 \ # 12 samples generated for each item and we compute reward for each +model.reference_free=true \ # Reference free means we dont use KL loss term. Only optimize for rewards +model.inference_cfg_prob=0.0 \ # fraction of generations generated using CFG. Can set > 0.0 if we want to optimize for both CFG and non CFG modes of generation ++model.use_local_transformer_prob=0.5 \ # fraction of generations generated using Local Transformer. Set it between 0.0 and 1.0 to improve both LT outputs and non LT outputs for models with an LT +model.inference_cfg_scale=2.5 \ # CFG scale for samples generated using CFG +model.cer_reward_weight=0.33 \ # weightage of CER reward in the overall reward +model.ssim_reward_weight=0.33 \ # weightage of SSIM reward in the overall reward From af12c35bd5d267f8b8ca41d826adfe0515e1c094 Mon Sep 17 00:00:00 2001 From: Subhankar Ghosh Date: Thu, 14 Aug 2025 17:01:24 -0400 Subject: [PATCH 070/113] [magpietts] Magpietts small and Attention changes, Evaluation (#14418) * Add small magpie config, inference changes * Bug fix * Inference and Evaluation cleanup and standardization Signed-off-by: subhankar-ghosh * Bug fixes and incorporating reviews Signed-off-by: subhankar-ghosh * Add toggle between training and inference for applying attn prior to x attn Signed-off-by: subhankar-ghosh * Using pytorch training flag Signed-off-by: subhankar-ghosh --------- Signed-off-by: subhankar-ghosh Co-authored-by: Jason --- .../magpietts_lhotse_dc_en_tiny.yaml | 191 ++++++++++ nemo/collections/tts/models/magpietts.py | 16 +- .../tts/modules/transformer_2501.py | 24 +- scripts/magpietts/README.md | 46 +++ scripts/magpietts/evalset_config.py | 347 ++---------------- scripts/magpietts/infer_and_evaluate.py | 169 ++++++--- 6 files changed, 398 insertions(+), 395 deletions(-) create mode 100644 examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml create mode 100644 scripts/magpietts/README.md diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml new file mode 100644 index 000000000000..b57ba7c50a45 --- /dev/null +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml @@ -0,0 +1,191 @@ +name: MagpieTTS-EN-Lhotse + +quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. + +model: + use_lhotse: true + model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer + use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. + context_duration_min: 5.0 + context_duration_max: 5.0 + load_cached_codes_if_available: true + prior_scaling_factor: 0.5 + prior_end_step: 12_000 + prior_scaledown_start_step: 8_000 # Prior will always be on before this step. + indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. + alignment_loss_scale: 0.002 + embedding_dim: 512 + codecmodel_path: ??? 
+ cfg_unconditional_prob: 0.1 # enable classifier-free guidance during traing by dropping out conditionals with this probability + + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10_000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50_000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 3 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + + text_tokenizers: # Add more languages for multi-lingual TTS + english_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + + train_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration : ??? # in seconds. Adjust based on your GPU memory. + quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 10 + bucket_buffer_size: 20_000 + shuffle_buffer_size: 20_000 + num_cuts_for_bins_estimate: 20_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 4 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + validation_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. + quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + drop_last: false + shuffle: false + num_workers: 4 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? 
+ weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + encoder: + n_layers: 6 + d_model: 512 + d_ffn: 2048 + sa_n_heads: 8 + kernel_size: 1 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: true + apply_norm_out: true + max_length_causal_mask: 2_048 + use_learnable_pos_emb: true + + decoder: + n_layers: 12 + d_model: 512 + d_ffn: 2048 + sa_n_heads: 8 + kernel_size: 1 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: true + xa_d_memory: 512 + xa_n_heads: 1 + xa_d_head: 128 + is_causal: true + apply_norm_to_cond: true + apply_norm_out: true + max_length_causal_mask: 2_048 + use_learnable_pos_emb: true + + optim: + _target_: torch.optim.AdamW + lr: 1e-4 + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: bf16-mixed + max_steps: ??? + accumulate_grad_batches: 1 + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + limit_train_batches: 1_000 + val_check_interval: 1_000 + num_sanity_val_steps: 0 + benchmark: false + use_distributed_sampler: false # required because Lhotse has its own handling + gradient_clip_val: 2.5 + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + entity: null + project: null + group: null + name: ${name} + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}-{epoch}' + resume_if_exists: true + resume_ignore_no_checkpoint: true diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 976acd95119f..b96225daa025 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -1472,7 +1472,7 @@ def get_most_attended_text_timestep(self, alignment_attention_scores, last_atten def construct_inference_prior(self, prior_epsilon, cross_attention_scores, text_lens, text_time_step_attended, attended_timestep_counter, - unfinished_texts, finished_texts_counter, end_indices, batch_size): + unfinished_texts, finished_texts_counter, end_indices, lookahead_window_size, batch_size): # Attn prior for the next timestep _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon _attn_prior = _attn_prior.to(cross_attention_scores.device) @@ -1483,17 +1483,16 @@ def construct_inference_prior(self, prior_epsilon, cross_attention_scores, # Very short sentences, No Prior _attn_prior[bidx, 0, :] = 1.0 else: - # _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-2)] = 0.1 # Slight exposure to history for better pronounciation. Not very important. - _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-1)] = 0.2 # Slight exposure to history for better pronounciation. Not very important. - _attn_prior[bidx, 0, text_time_step_attended[bidx]] = 0.8 # Slightly bias to continue moving forward. Not very important. 
- _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+1, _text_len - 1) ] = 1.0 - _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+2, _text_len - 1) ] = 0.8 + _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-1)] = 1.0 # Slight exposure to history for better pronounciation. Not very important. + _attn_prior[bidx, 0, text_time_step_attended[bidx]] = 1.0 # Slightly bias to continue moving forward. Not very important. + for ind in range(1, lookahead_window_size + 1): + _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+ind, _text_len - 1) ] = 1.0 # Penalize timesteps that have been attended to more than 10 times for _timestep in attended_timestep_counter[bidx]: if attended_timestep_counter[bidx][_timestep] >= 10: # This means the timestep has been attended to more than 10 times (To avoid getting stuck) - _attn_prior[bidx, 0, _timestep] = prior_epsilon + _attn_prior[bidx, 0, :_timestep+1] = prior_epsilon unfinished_texts[bidx] = False if text_time_step_attended[bidx] < text_lens[bidx] - 3: @@ -1501,7 +1500,7 @@ def construct_inference_prior(self, prior_epsilon, cross_attention_scores, if bidx not in end_indices: unfinished_texts[bidx] = True - if text_time_step_attended[bidx] >= text_lens[bidx] - 5 or bidx in end_indices: + if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices: if bidx not in finished_texts_counter: finished_texts_counter[bidx] = 0 @@ -1698,6 +1697,7 @@ def infer_batch( unfinished_texts=unfinished_texts, finished_texts_counter=finished_texts_counter, end_indices=end_indices, + lookahead_window_size=lookahead_window_size, batch_size=batch_size ) diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index 59b1986ac348..af22c0d1380c 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -231,15 +231,21 @@ def attn_naive( if attn_prior is not None: eps = torch.finfo(attn_prior.dtype).tiny attn_prior = attn_prior[:, :T] # trim for inference - attn_prior = attn_prior[:, None] - attn_prior_log = torch.log(attn_prior + eps) - attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior_log - if self.make_prior_window_strict: - # Make sure attention scores are lowest (eps) where prior is zero. - min_score = torch.log(torch.tensor(eps)).to(attn_score_log.device) - attn_score_log = attn_score_log.masked_fill(attn_prior == 0, min_score) # Wherever prior is zero, set scores to eps. - attn_score_log = torch.clamp(attn_score_log, min=min_score) # Make sure scores are not less than eps. - attn_prob = F.softmax(attn_score_log, dim=-1) + attn_prior = attn_prior[:, None] + eps + # Use PyTorch's built-in training flag to branch behavior + if self.training: + attn_prior_log = torch.log(attn_prior) + attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior_log + if self.make_prior_window_strict: + # Make sure attention scores are lowest (eps) where prior is zero. + min_score = torch.log(torch.tensor(eps)).to(attn_score_log.device) + attn_score_log = attn_score_log.masked_fill(attn_prior == 0, min_score) # Wherever prior is zero, set scores to eps. + attn_score_log = torch.clamp(attn_score_log, min=min_score) # Make sure scores are not less than eps. 
+ attn_prob = F.softmax(attn_score_log, dim=-1) + else: + attn_prob = F.softmax(attn_score, dim=-1) + attn_prob = attn_prob * attn_prior + attn_prob = attn_prob / (attn_prob.sum(dim=-1, keepdim=True)) # normalize else: attn_prob = F.softmax(attn_score, dim=-1) diff --git a/scripts/magpietts/README.md b/scripts/magpietts/README.md new file mode 100644 index 000000000000..ef99015e65c0 --- /dev/null +++ b/scripts/magpietts/README.md @@ -0,0 +1,46 @@ +# MagpieTTS Inference and Evaluation + +To evaluate any MagpieTTS checkpoint you trained follow the steps as shown below (INTERNAL ONLY): + +1) Mount the EOS cluster path `/lustre/fsw/llmservice_nemo_speechlm/data/TTS:/Data` + +All the needed manifests are here: `/lustre/fsw/llmservice_nemo_speechlm/data/TTS/evaluation_manifests` + +2) Run the following command: +``` +CKPT= +HPARAM= +CODEC= +OUT_DIR= + +python scripts/magpietts/infer_and_evaluate.py \ +--checkpoint_files ${CKPT} \ +--hparams_files ${HPARAM} \ +--codecmodel_path ${CODEC} \ +--out_dir ${OUT_DIR} \ +--use_cfg \ +--apply_attention_prior +``` + +**Test Sets** +The Datasets that we evaluate on are: + +- LibriTTS test clean +- LibriTTS seen +- VCTK +- RIVA Hard examples + +**Evaluation Metrics** + +- ASR of the generated speech is done using `nvidia/parakeet-tdt-1.1b` and then CER/WER is computed. +- Speaker Similarity using `titanet` + + + +# Using Lhotse Datasets in MagpieTTS + +Refer to [this file](./README_lhotse.md) for more information about using Lhotse Dataset of MagpieTTS. + +# Preference Alignment of MagpieTTS + +Refer to [this file](./README_magpie_po.md) for more information about preference alignment of MagpieTTS. \ No newline at end of file diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 01bba47c6b2b..a36ec44db20c 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -12,336 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. 
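On the CER/WER metrics mentioned in the README above: a generic way to compute them is sketched below using the jiwer library. This is only an illustration and is not necessarily how evaluate_generated_audio.py computes them.

```
import jiwer

reference = "the quick brown fox jumps over the lazy dog"
hypothesis = "the quik brown fox jump over the lazy dog"
print("WER:", jiwer.wer(reference, hypothesis))  # word error rate
print("CER:", jiwer.cer(reference, hypothesis))  # character error rate
```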
dataset_meta_info = { - 'vctk': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json', - 'audio_dir' : '/datap/misc/Datasets/VCTK-Corpus', - 'feature_dir' : '/datap/misc/Datasets/VCTK-Corpus', - }, - 'riva_challenging': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', - 'audio_dir' : '/datap/misc/Datasets/riva', - 'feature_dir' : '/datap/misc/Datasets/riva', - }, - 'riva_challenging_shehzeen': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths_v2.json', - 'audio_dir' : '/Data/RivaData/riva', - 'feature_dir' : '/Data/RivaData/riva', - }, - 'rough_qwen': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/rough.json', - 'audio_dir' : '/Data/RivaData/riva', - 'feature_dir' : '/Data/RivaData/riva', - 'tokenizer_names': ['qwen'], - 'load_cached_codes_if_available': False, - }, - 'riva_challenging_nozeros': { - # 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/riva_challenging_nozeros.json', - 'manifest_path': '/home/pneekhara/2023/SimpleT5NeMo/manifests/riva_challenging_filtered.json', - 'audio_dir' : '/datap/misc/Datasets/riva', - 'feature_dir' : '/datap/misc/Datasets/riva', - }, - 'libri_dev_clean_eval_large': { - 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_large.json', - 'audio_dir' : '/datap/misc/Datasets/LibriTTS', - 'feature_dir' : '/datap/misc/Datasets/LibriTTS', - }, - 'libri_dev_clean_eval_mid': { - 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_mid.json', - 'audio_dir' : '/datap/misc/Datasets/LibriTTS', - 'feature_dir' : '/datap/misc/Datasets/LibriTTS', - }, - 'libri_dev_clean_eval_tiny': { - 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/dev_clean_withContextAudioPaths_withTargetCodes_evalset_tiny.json', - 'audio_dir' : '/datap/misc/Datasets/LibriTTS', - 'feature_dir' : '/datap/misc/Datasets/LibriTTS', - }, - 'libri_val': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', - 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - }, - 'libri_val_12.5': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', - 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'load_cached_codes_if_available': False - }, - 'libri_val_shehzeen': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/libri360_val.json', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', - }, - 'libri_unseen_test': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/test_clean_withContextAudioPaths.json', - 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - }, - 'libri_unseen_test_12.5': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/test_clean_withContextAudioPaths.json', - 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'load_cached_codes_if_available': 
False - }, - 'riva_val_text_context': { - 'manifest_path' : '/datap/misc/speechllm_codecdatasets/manifests/t5_exp/RivattsEnglishLindyRodney21fps_val_nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_phoneme_tts_TextContext.json', - 'audio_dir' : "/datap/misc/Datasets/riva/RivattsEnglish", - 'feature_dir' : '/', - }, - 'libri_unseen_test_shehzeen_phoneme': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', - 'tokenizer_names': ['english_phoneme'], - }, - 'libri_unseen_test_shehzeen_sep_char': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', - 'tokenizer_names': ['english_chartokenizer'], - }, - 'libri_unseen_test_shehzeen_shared_char': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', - 'tokenizer_names': ['chartokenizer'], - }, - 'libri_unseen_test_shehzeen_shared_char_ipa': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests_ipa/test_clean_withContextAudioPaths_ipa.json', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', - 'tokenizer_names': ['chartokenizer'], - }, - 'grpo_valset': { - 'manifest_path' : '/Data/DPOPairsInputDatav2/text_context_pairs_grpo_val_unseenspeakers.json', - 'audio_dir' : '/', - 'feature_dir' : '/', + 'riva_hard_digits': { + 'manifest_path' : '/Data/evaluation_manifests/hard-digits-path-corrected.ndjson', + 'audio_dir' : '/Data/RIVA-TTS', + 'feature_dir' : '/Data/RIVA-TTS', + }, + 'riva_hard_letters': { + 'manifest_path' : '/Data/evaluation_manifests/hard-letters-path-corrected.ndjson', + 'audio_dir' : '/Data/RIVA-TTS', + 'feature_dir' : '/Data/RIVA-TTS', + }, + 'riva_hard_money': { + 'manifest_path' : '/Data/evaluation_manifests/hard-money-path-corrected.ndjson', + 'audio_dir' : '/Data/RIVA-TTS', + 'feature_dir' : '/Data/RIVA-TTS', + }, + 'riva_hard_short': { + 'manifest_path' : '/Data/evaluation_manifests/hard-short-path-corrected.ndjson', + 'audio_dir' : '/Data/RIVA-TTS', + 'feature_dir' : '/Data/RIVA-TTS', }, - 'libri_unseen_test_shehzeen_sp': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', - 'tokenizer_names': ['multilingual_sentencepiece'], + 'vctk': { + 'manifest_path' : '/Data/evaluation_manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths_silence_trimmed.json', + 'audio_dir' : '/Data/VCTK-Corpus-0.92', + 'feature_dir' : '/Data/VCTK-Corpus-0.92', }, - 'libri_unseen_test_shehzeen': { - 'manifest_path' : '/home/shehzeenh/Code/NewT5TTS/manifests/test_clean_withContextAudioPaths.json', + 'libritts_seen': { + 'manifest_path' : '/Data/evaluation_manifests/LibriTTS_seen_evalset_from_testclean_v2.json', 'audio_dir' : '/Data/LibriTTS', 'feature_dir' : '/Data/LibriTTS', }, - 'libri_seen_test_v2': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri_seen_evalset_from_testclean_v2.json', - 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - }, - 'libri_seen_test_v2_shehzeen': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri_seen_evalset_from_testclean_v2.json', + 'libritts_test_clean': { + 
'manifest_path' : '/Data/evaluation_manifests/LibriTTS_test_clean_withContextAudioPaths.json', 'audio_dir' : '/Data/LibriTTS', 'feature_dir' : '/Data/LibriTTS', }, - 'libri_unseen_val': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/dev_clean_withContextAudioPaths_evalset.json', - 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - }, - 'spanish_cml_phoneme': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'tokenizer_names': ['spanish_phoneme'], - 'whisper_language': 'es' - }, - 'spanish_cml_sep_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'tokenizer_names': ['spanish_chartokenizer'], - 'whisper_language': 'es' - }, - 'spanish_cml_shared_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'es' - }, - 'spanish_cml_sp': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_spanish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_spanish_v0.1', - 'tokenizer_names': ['multilingual_sentencepiece'], - 'whisper_language': 'es' - }, - 'german_cml_phoneme': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'tokenizer_names': ['german_phoneme'], - 'whisper_language': 'de', - 'load_cached_codes_if_available': False - }, - 'german_cml_sep_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'tokenizer_names': ['german_chartokenizer'], - 'whisper_language': 'de', - 'load_cached_codes_if_available': False - }, - 'german_cml_shared_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'de', - 'load_cached_codes_if_available': False - }, - 'german_cml_sp': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_german_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_german_v0.1', - 'tokenizer_names': ['multilingual_sentencepiece'], - 'whisper_language': 'de', - 'load_cached_codes_if_available': False - }, - 'french_cml_sep_char': { - 
'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_french_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_french_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_french_v0.1', - 'tokenizer_names': ['french_chartokenizer'], - 'whisper_language': 'fr', - 'load_cached_codes_if_available': False - }, - 'french_cml_shared_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_french_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_french_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_french_v0.1', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'fr', - 'load_cached_codes_if_available': False - }, - 'french_cml_sp': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_french_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_french_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_french_v0.1', - 'tokenizer_names': ['multilingual_sentencepiece'], - 'whisper_language': 'fr', - 'load_cached_codes_if_available': False - }, - 'italian_cml_sep_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_italian_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', - 'tokenizer_names': ['italian_chartokenizer'], - 'whisper_language': 'it', - 'load_cached_codes_if_available': False - }, - 'italian_cml_shared_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_italian_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'it', - 'load_cached_codes_if_available': False - }, - 'italian_cml_sp': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_italian_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_italian_v0.1', - 'tokenizer_names': ['multilingual_sentencepiece'], - 'whisper_language': 'it', - 'load_cached_codes_if_available': False - }, - 'dutch_cml_sep_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_dutch_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', - 'tokenizer_names': ['dutch_chartokenizer'], - 'whisper_language': 'nl', - 'load_cached_codes_if_available': False - }, - 'dutch_cml_shared_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_dutch_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'nl', - 'load_cached_codes_if_available': False - }, - 'dutch_cml_sp': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_dutch_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_dutch_v0.1', - 'tokenizer_names': 
['multilingual_sentencepiece'], - 'whisper_language': 'nl', - 'load_cached_codes_if_available': False - }, - 'portuguese_cml_sep_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_portuguese_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', - 'tokenizer_names': ['portuguese_chartokenizer'], - 'whisper_language': 'pt', - 'load_cached_codes_if_available': False - }, - 'portuguese_cml_shared_char_ipa': { - 'manifest_path' : '/Data/CML/manifests_with_codecs_ipa3/cml_tts_dataset_portuguese_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset_ipa.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_portuguese_v0.1', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'pt', - 'load_cached_codes_if_available': False - }, - 'polish_cml_sep_char': { - 'manifest_path' : '/Data/CML/manifests_with_codecs/cml_tts_dataset_polish_v0.1/test_withAudioCodes_codec21Khz_no_eliz_filtered_100subset.json', - 'audio_dir': '/Data/CML/cml_tts_dataset_polish_v0.1', - 'feature_dir': '/Data/CML/cml_tts_dataset_polish_v0.1', - 'tokenizer_names': ['polish_chartokenizer'], - 'whisper_language': 'pl', - 'load_cached_codes_if_available': False - }, - 'j_libri_unseen_test_no_codes': { - 'manifest_path' : '/home/jasoli/data_prime/manifests/test_clean_withContextAudioPaths.json', - 'audio_dir' : '/mnt/drive1/data/LibriTTS/', - 'feature_dir' : None, - }, - 'hindi_indic_shared_char': { - 'manifest_path' : '/Data/IndicDataset/manifests_ipa/hindi_100_test.json', - 'audio_dir': '/', - 'feature_dir': '/', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'hi', - 'load_cached_codes_if_available': False - }, - 'bengali_indic_shared_char': { - 'manifest_path' : '/Data/IndicDataset/manifests_ipa/bengali_100_test.json', - 'audio_dir': '/', - 'feature_dir': '/', - 'tokenizer_names': ['chartokenizer'], - 'whisper_language': 'bn', - 'load_cached_codes_if_available': False - }, + # We need an4_val_ci just for CI tests 'an4_val_ci': { 'manifest_path' : '/home/TestData/an4_dataset/an4_val_context_v1.json', 'audio_dir' : '/', 'feature_dir' : None, }, - 'j_riva_digits': { - 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-digits-RivaEnContext.ndjson', - 'audio_dir' : '/', - 'feature_dir' : None, - }, - 'j_riva_letters': { - 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-letters-RivaEnContext.ndjson', - 'audio_dir' : '/', - 'feature_dir' : None, - }, - 'j_riva_money': { - 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-digits-RivaEnContext.ndjson', - 'audio_dir' : '/', - 'feature_dir' : None, - }, - 'j_riva_short': { - 'manifest_path' : '/home/jasoli/data_prime/manifests/hard-short-RivaEnContext.ndjson', - 'audio_dir' : '/', - 'feature_dir' : None, - }, } \ No newline at end of file diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index c4132c10093f..d6429d4eafff 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -38,6 +38,9 @@ from nemo.collections.tts.models import MagpieTTSModel from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer +# EVALUATION_DATASETS is the full list of datasets for evaluation of a new model. 
+EVALUATION_DATASETS = "riva_hard_digits,riva_hard_letters,riva_hard_money,riva_hard_short,vctk,libritts_seen,libritts_test_clean" + def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): metrics = {} for key in metric_keys: @@ -164,6 +167,87 @@ def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: plt.savefig(output_png, format="png", bbox_inches="tight") +def create_combined_violin_plots(dataset_metrics: dict, metric_keys: List[str], output_png: str): + """ + Create box plots comparing multiple datasets for each metric in a single figure. + + Args: + dataset_metrics: Dictionary where keys are dataset names and values are lists of metric dictionaries + metric_keys: List of metric names to plot + output_png: Output file path for the combined plot + """ + # Prepare data for plotting + datasets = list(dataset_metrics.keys()) + num_datasets = len(datasets) + num_metrics = len(metric_keys) + + # Create figure with subplots for each metric + fig, axs = plt.subplots(1, num_metrics, figsize=(num_metrics * 6, 6)) + + # Handle case where there's only one metric (axs won't be an array) + if num_metrics == 1: + axs = [axs] + + # Define colors for different datasets + colors = plt.cm.Set3(np.linspace(0, 1, num_datasets)) + + for metric_idx, metric in enumerate(metric_keys): + ax = axs[metric_idx] + + # Collect data for all datasets for this metric + all_data = [] + positions = [] + dataset_labels = [] + + for dataset_idx, dataset in enumerate(datasets): + df = pd.DataFrame(dataset_metrics[dataset]) + if metric in df.columns: + data = df[metric].dropna() + all_data.append(data) + positions.append(dataset_idx + 1) + dataset_labels.append(dataset) + + # Create box plots + if all_data: + bp = ax.boxplot(all_data, positions=positions, widths=0.6, patch_artist=True, + showmeans=True, meanline=False, meanprops={'marker': 'o', 'markerfacecolor': 'red', + 'markeredgecolor': 'red', 'markersize': 6}) + + # Color the box plots + for i, patch in enumerate(bp['boxes']): + patch.set_facecolor(colors[i]) + patch.set_alpha(0.7) + + # Add mean labels for each dataset + for i, (data, pos) in enumerate(zip(all_data, positions)): + mean = data.mean() + sem = data.sem() + + label_numeric = f"{mean:.3f}±{1.96 * sem:.3f}" + ax.text(pos + 0.1, mean, label_numeric, ha="left", va="center", fontsize=8) + + # Set labels and title + ax.set_title(f"{metric.upper()}", fontsize=12, fontweight='bold') + ax.set_xticks(positions) + ax.set_xticklabels(dataset_labels, rotation=45, ha='right') + ax.grid(True, linestyle="dotted", alpha=0.7) + ax.set_xlabel("Dataset") + ax.set_ylabel(metric) + + # Set y-axis limit for CER metrics + if 'cer' in metric.lower(): + ax.set_ylim(0, 0.3) + + # Add overall title + fig.suptitle("Performance Comparison Across Datasets", fontsize=14, fontweight='bold') + + # Adjust layout and save + plt.tight_layout() + plt.savefig(output_png, format="png", bbox_inches="tight", dpi=300) + plt.close() + print(f"Combined violin plot saved to: {output_png}") + + def run_inference( hparams_file, checkpoint_file, @@ -240,7 +324,6 @@ def run_inference( else: exp_name = "" - checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] checkpoint_name = "{}{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_LT_{}_MGsteps_{}_ST_{}_sched_{}".format( exp_name, checkpoint_name, @@ -262,6 +345,7 @@ def run_inference( dataset_meta_info = evalset_config.dataset_meta_info ssim_per_dataset = [] cer_per_dataset = [] + all_datasets_filewise_metrics = {} # Store filewise metrics for all 
datasets for combined violin plot for dataset in datasets: print(f"Evaluating dataset {dataset}") metrics_n_repeated = [] @@ -293,6 +377,7 @@ def run_inference( context_duration_min = 5.0 context_duration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. + dataset_filewise_metrics_all_repeats = [] # Store metrics for all repeats of this dataset for repeat_idx in range(num_repeats): pred_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(pred_audio_dir, exist_ok=True) @@ -414,6 +499,8 @@ def run_inference( codecmodel_path=codecmodel_path if compute_fcd else None ) metrics_n_repeated.append(metrics) + dataset_filewise_metrics_all_repeats.extend(filewise_metrics) # Collect all filewise metrics for combined plot + with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: json.dump(metrics, f, indent=4) @@ -439,6 +526,9 @@ def run_inference( for codes_file in codec_file_paths: os.remove(codes_file) + # Store filewise metrics for this dataset for combined plotting + all_datasets_filewise_metrics[dataset] = dataset_filewise_metrics_all_repeats + metric_keys = ['cer_filewise_avg', 'wer_filewise_avg', 'cer_cumulative', 'wer_cumulative', 'ssim_pred_gt_avg', 'ssim_pred_context_avg', 'ssim_gt_context_avg', 'ssim_pred_gt_avg_alternate', 'ssim_pred_context_avg_alternate', 'ssim_gt_context_avg_alternate', @@ -471,6 +561,11 @@ def run_inference( cer_current = np.mean(measurements) cer_per_dataset.append(cer_current) + # Create combined violin plot for all datasets + if len(all_datasets_filewise_metrics) > 1: # Only create combined plot if we have multiple datasets + combined_output_png = os.path.join(out_dir, f"{checkpoint_name}_combined_violin_plot.png") + create_combined_violin_plots(all_datasets_filewise_metrics, violin_plot_metrics, combined_output_png) + # Average across datasets ssim = np.mean(ssim_per_dataset) cer = np.mean(cer_per_dataset) @@ -481,17 +576,13 @@ def run_inference( def main(): parser = argparse.ArgumentParser(description='Experiment Evaluation') - parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml") + parser.add_argument('--hparams_files', type=str, default=None) parser.add_argument('--hparams_file_from_wandb', action='store_true') - parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt") + parser.add_argument('--checkpoint_files', type=str, default=None) parser.add_argument('--nemo_files', type=str, default=None) - parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo", help="Path to codec model (used for FCD computation unless --disable_fcd is specified)") - parser.add_argument('--datasets', type=str, default="libri_unseen_test_12.5") - parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmountedresson/") - parser.add_argument('--draco_exp_dir', type=str, default="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/experiments/NewT5TTS_FixedPosEmb/AllKernselSize3/EdressonCodecExperiments/") - parser.add_argument('--server_address', type=str, default="pneekhara@login-eos02.eos.clusters.nvidia.com") - 
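For reference, here is a standalone usage sketch of the create_combined_violin_plots() helper added above. The dataset names and metric values are made up, and the call assumes the script's pandas/matplotlib imports are available.

```
filewise_metrics_by_dataset = {
    "vctk": [
        {"cer": 0.021, "pred_context_ssim": 0.61},
        {"cer": 0.034, "pred_context_ssim": 0.58},
        {"cer": 0.027, "pred_context_ssim": 0.63},
    ],
    "libritts_test_clean": [
        {"cer": 0.018, "pred_context_ssim": 0.64},
        {"cer": 0.025, "pred_context_ssim": 0.60},
        {"cer": 0.022, "pred_context_ssim": 0.59},
    ],
}
create_combined_violin_plots(
    dataset_metrics=filewise_metrics_by_dataset,
    metric_keys=["cer", "pred_context_ssim"],
    output_png="combined_violin_plot.png",
)
```

Despite the name, the helper draws box plots with mean markers, one panel per metric, so it gives a quick side-by-side view of per-file CER and speaker-similarity spread across datasets.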
parser.add_argument('--exp_names', type=str, default="koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN1,koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN3") - parser.add_argument('--local_ckpt_dir', type=str, default="/datap/misc/experiment_checkpoints/localtransformer") + parser.add_argument('--codecmodel_path', type=str, default=None, help="Path to codec model") + parser.add_argument('--datasets', type=str, default=None) + # Parameters for running inference experiments locally parser.add_argument('--out_dir', type=str, default="/datap/misc/Evals/LocalTransformerAblations2") parser.add_argument('--temperature', type=float, default=0.6) parser.add_argument('--use_cfg', action='store_true') @@ -499,13 +590,14 @@ def main(): parser.add_argument('--maskgit_n_steps', type=int, default=3) parser.add_argument('--cfg_scale', type=float, default=2.5) parser.add_argument('--apply_attention_prior', action='store_true') - parser.add_argument('--attention_prior_epsilon', type=float, default=1e-3) - parser.add_argument('--attention_prior_lookahead_window', type=int, default=10) + parser.add_argument('--attention_prior_epsilon', type=float, default=0.1) + parser.add_argument('--attention_prior_lookahead_window', type=int, default=5) parser.add_argument('--estimate_alignment_from_layers', type=str, default=None) parser.add_argument('--apply_prior_to_layers', type=str, default=None) - parser.add_argument('--start_prior_after_n_audio_steps', type=int, default=10) + parser.add_argument('--start_prior_after_n_audio_steps', type=int, default=0) parser.add_argument('--topk', type=int, default=80) - parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--batch_size', type=int, default=32) + # Parameters for evaluation parser.add_argument('--sv_model', type=str, default="titanet") # titanet, wavlm parser.add_argument('--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b") # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b parser.add_argument('--num_repeats', type=int, default=1) @@ -519,6 +611,9 @@ def main(): parser.add_argument('--violin_plot_metrics', type=str, nargs='*', default=['cer','pred_context_ssim'], help="Which metrics to add the violin plot.") args = parser.parse_args() + if args.datasets is None: + args.datasets = EVALUATION_DATASETS + # FCD computation is enabled by default, disabled only when --disable_fcd is specified compute_fcd = not args.disable_fcd @@ -582,55 +677,11 @@ def main(): checkpoint_file=None, nemo_file=nemo_file, ) - # Mode 3: Discover and run experiments from a base directory - # Mount DRACO_EXP_DIR to BASE_EXP_DIR as follows: - # sshfs -o allow_other pneekhara@draco-oci-dc-02.draco-oci-iad.nvidia.com:/lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/experiments/NewT5AllFixedFresh /datap/misc/dracomount/ - elif args.base_exp_dir: - BASE_EXP_DIR = args.base_exp_dir - DRACO_EXP_DIR = args.draco_exp_dir - if args.exp_names is None: - exp_names = os.listdir(BASE_EXP_DIR) - else: - exp_names = args.exp_names.split(",") - - for exp_name in exp_names: - exp_dir = os.path.join(BASE_EXP_DIR, exp_name) - # recurisvely look for hparams.yaml - try: - hparams_file = glob.glob(f"{exp_dir}/**/hparams.yaml", recursive=True)[0] - checkpoints_dir = glob.glob(f"{exp_dir}/**/checkpoints", recursive=True)[0] - last_checkpoint = (glob.glob(f"{checkpoints_dir}/*last.ckpt"))[0] - except: - print(f"Skipping experiment {exp_name} as hparams or last checkpoint not found.") - continue - last_checkpoint_path_draco = 
last_checkpoint.replace(BASE_EXP_DIR, DRACO_EXP_DIR) - epoch_num = last_checkpoint.split("epoch=")[1].split("-")[0] - - checkpoint_copy_path = os.path.join(args.local_ckpt_dir, f"{exp_name}_epoch_{epoch_num}.ckpt") - hparams_copy_path = os.path.join(args.local_ckpt_dir, f"{exp_name}_hparams.yaml") - - scp_command = f"scp {args.server_address}:{last_checkpoint_path_draco} {checkpoint_copy_path}" - print(f"Running command: {scp_command}") - os.system(scp_command) - print("Copied checkpoint.") - hparams_path_draco = hparams_file.replace(BASE_EXP_DIR, DRACO_EXP_DIR) - scp_command_hparams = f"scp {args.server_address}:{hparams_path_draco} {hparams_copy_path}" - print(f"Running command: {scp_command_hparams}") - os.system(scp_command_hparams) - print("Copied hparams file.") - print("Hparams file path: ", hparams_copy_path) - print("Checkpoint file path: ", checkpoint_copy_path) - run_inference_w_args( - hparams_copy_path, - checkpoint_copy_path, - nemo_file=None, - ) else: parser.error( "You must provide a model to run. Please specify either:\n" "1. --hparams_files and --checkpoint_files\n" "2. --nemo_file\n" - "3. --base_exp_dir to discover experiments" ) if args.cer_target is not None and cer > float(args.cer_target): raise ValueError() @@ -639,4 +690,4 @@ def main(): if __name__ == '__main__': - main() + main() \ No newline at end of file From a307a3da02eee3c420af5b76055ecbe75dabb1c6 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 19 Aug 2025 12:47:03 -0400 Subject: [PATCH 071/113] Fix path in config_evalset.py (#14509) Fixed to the actual name of the file on EOS Signed-off-by: Fejgin, Roy --- scripts/magpietts/evalset_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index a36ec44db20c..4add09cafc78 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -43,7 +43,7 @@ 'feature_dir' : '/Data/LibriTTS', }, 'libritts_test_clean': { - 'manifest_path' : '/Data/evaluation_manifests/LibriTTS_test_clean_withContextAudioPaths.json', + 'manifest_path' : '/Data/evaluation_manifests/LibriTTS_test_clean_withContextAudioPaths.jsonl', 'audio_dir' : '/Data/LibriTTS', 'feature_dir' : '/Data/LibriTTS', }, From 2aa2c6972817abd14baeed7bb063afd1adb53dc5 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Tue, 19 Aug 2025 12:48:29 -0400 Subject: [PATCH 072/113] =?UTF-8?q?Add=20an=20option=20to=20seed=20the=20r?= =?UTF-8?q?andom=20number=20generators=20for=20debugging=20/=20re=E2=80=A6?= =?UTF-8?q?=20(#14510)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add an option to seed the random number generators for debugging / reproducibility. 
Signed-off-by: Fejgin, Roy * Cleanup Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- examples/tts/magpietts.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index 5681600f8617..8c74111ebef2 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -47,6 +47,12 @@ def main(cfg): trainer.callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step', log_weight_decay=True)) exp_manager(trainer, cfg.get("exp_manager", None)) + seed = cfg.get('seed', None) + if seed is not None: + # Option to seed for debugging + logging.info(f"Setting seed to {seed}") + pl.seed_everything(seed, workers=True) + mode = cfg.get('mode', 'train') if mode == 'train': model = MagpieTTSModel(cfg=cfg.model, trainer=trainer) From bdb41dfce742f618a875667c102e50447079de70 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Wed, 20 Aug 2025 12:58:06 -0400 Subject: [PATCH 073/113] Avoid speaker embedding crash on short signals (#14525) * Avoid speaker embedding crash on short signals The alternate speaker embedding model was not using the extract_embedding() method, which takes care of padding very short signals (which can sometimes be generated during evaluation) so that they don't crash the embedding model. This caused crashes during evaluation. Fixed it to use extract_embedding() like the default speaker embedding model does. Signed-off-by: Fejgin, Roy * Comments Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- scripts/magpietts/evaluate_generated_audio.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 3553ac98fee8..67e22b4d3181 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -18,6 +18,7 @@ import string import logging from contextlib import contextmanager +from functools import partial import numpy as np import torch @@ -252,22 +253,31 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo pred_context_ssim = 0.0 gt_context_ssim = 0.0 with torch.no_grad(): - gt_speaker_embedding = extract_embedding(speaker_verification_model, feature_extractor, gt_audio_filepath, device, sv_model_type) - pred_speaker_embedding = extract_embedding(speaker_verification_model, feature_extractor, pred_audio_filepath, device, sv_model_type) + extract_embedding_fn = partial(extract_embedding, model=speaker_verification_model, extractor=feature_extractor, device=device, sv_model_type=sv_model_type) + extract_embedding_fn_alternate = partial(extract_embedding, model=speaker_verification_model_alternate, extractor=feature_extractor, device=device, sv_model_type=sv_model_type) + + # Ground truth vs. predicted + gt_speaker_embedding = extract_embedding_fn(audio_path=gt_audio_filepath) + pred_speaker_embedding = extract_embedding_fn(audio_path=pred_audio_filepath) pred_gt_ssim = torch.nn.functional.cosine_similarity(gt_speaker_embedding, pred_speaker_embedding, dim=0).item() - gt_speaker_embedding_alternate = speaker_verification_model_alternate.get_embedding(gt_audio_filepath).squeeze() - pred_speaker_embedding_alternate = speaker_verification_model_alternate.get_embedding(pred_audio_filepath).squeeze() + # Ground truth vs. 
predicted (alternate model) + gt_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=gt_audio_filepath) + pred_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=pred_audio_filepath) pred_gt_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0).item() if context_audio_filepath is not None: - context_speaker_embedding = extract_embedding(speaker_verification_model, feature_extractor, context_audio_filepath, device, sv_model_type) - context_speaker_embedding_alternate = speaker_verification_model_alternate.get_embedding(context_audio_filepath).squeeze() - + context_speaker_embedding = extract_embedding_fn(audio_path=context_audio_filepath) + context_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=context_audio_filepath) + + # Predicted vs. context pred_context_ssim = torch.nn.functional.cosine_similarity(pred_speaker_embedding, context_speaker_embedding, dim=0).item() + # Ground truth vs. context gt_context_ssim = torch.nn.functional.cosine_similarity(gt_speaker_embedding, context_speaker_embedding, dim=0).item() + # Predicted vs. context (alternate model) pred_context_ssim_alternate = torch.nn.functional.cosine_similarity(pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() + # Ground truth vs. context (alternate model) gt_context_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() filewise_metrics.append({ From b3c253f0d627286d46d785f23b0a57d0b50adb0b Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Wed, 20 Aug 2025 17:09:05 -0400 Subject: [PATCH 074/113] Inference: Fix for speaker embedding of short files (#14533) * Inference: Fix for speaker embedding of short files The padded signal wasn't actually being given to the embedding model. Needed to write it to a temporary file first. Signed-off-by: Fejgin, Roy * Use torch.inference_mode() instead of torch.no_grad() It's a bit more efficient. 
Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- scripts/magpietts/evaluate_generated_audio.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 67e22b4d3181..9e2bf08e5544 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -19,6 +19,8 @@ import logging from contextlib import contextmanager from functools import partial +import soundfile as sf +import tempfile import numpy as np import torch @@ -81,7 +83,7 @@ def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, langua inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features inputs = inputs.to(device) # Generate transcription - with torch.no_grad(): + with torch.inference_mode(): predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) # Decode transcription @@ -129,11 +131,14 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type): speech_array = pad_audio_to_min_length(speech_array, int(sampling_rate), min_seconds=0.5) if sv_model_type == "wavlm": inputs = extractor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device) - with torch.no_grad(): + with torch.inference_mode(): embeddings = model(inputs).embeddings else: # Titanet - with torch.no_grad(): - embeddings = model.get_embedding(audio_path).squeeze() + with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + # the embedding model doesn't accept NumPy arrays, so we write to a temporary file + sf.write(temp_file.name, speech_array, samplerate=16000) + with torch.inference_mode(): + embeddings = model.get_embedding(temp_file.name).squeeze() return embeddings.squeeze() @@ -210,7 +215,7 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo try: if language == "en": - with torch.no_grad(): + with torch.inference_mode(): pred_text = asr_model.transcribe([pred_audio_filepath])[0].text pred_text = process_text(pred_text) gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0].text @@ -252,7 +257,7 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo pred_context_ssim = 0.0 gt_context_ssim = 0.0 - with torch.no_grad(): + with torch.inference_mode(): extract_embedding_fn = partial(extract_embedding, model=speaker_verification_model, extractor=feature_extractor, device=device, sv_model_type=sv_model_type) extract_embedding_fn_alternate = partial(extract_embedding, model=speaker_verification_model_alternate, extractor=feature_extractor, device=device, sv_model_type=sv_model_type) From f78d0bf1e84a898c1f1cfc39e5b7a5f385934fdb Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Thu, 21 Aug 2025 20:48:54 +0530 Subject: [PATCH 075/113] [magpietts] custom tokenizer for text context (#14408) * make text conditioning tokenizer customizable Signed-off-by: Paarth Neekhara * backward compatibility bug fix Signed-off-by: Paarth Neekhara * add none option for custom tokenizer and default to it Signed-off-by: Paarth Neekhara * change end detection hard coded param - allow EOS prediction earlier Signed-off-by: Paarth Neekhara * bug fixes Signed-off-by: Paarth Neekhara * use google-t5/t5-small text_conditioning_tokenizer_name as default for backward compatibility Signed-off-by: Paarth Neekhara * code cleanup Signed-off-by: Paarth Neekhara * eval set config bug fix 
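The two fixes above (padding very short signals, then writing the padded array to a temporary file because the embedding model only accepts file paths) boil down to the small sketch below. get_embedding stands in for whatever path-based embedding API is used; it is an assumption for illustration, not the actual NeMo call.

```
import tempfile

import numpy as np
import soundfile as sf


def embed_with_padding(get_embedding, audio: np.ndarray, sample_rate: int = 16000, min_seconds: float = 0.5):
    # Pad very short signals so the embedding model does not crash on them.
    min_samples = int(min_seconds * sample_rate)
    if len(audio) < min_samples:
        audio = np.pad(audio, (0, min_samples - len(audio)))
    # The embedding model only accepts file paths, so write the padded signal to a temp file.
    with tempfile.NamedTemporaryFile(suffix=".wav") as tmp:
        sf.write(tmp.name, audio, samplerate=sample_rate)
        return get_embedding(tmp.name)
```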
Signed-off-by: Paarth Neekhara * Refactor text conditioning logic Signed-off-by: Paarth Neekhara * clean up Signed-off-by: Paarth Neekhara * cleanup Signed-off-by: Paarth Neekhara * cleanup Signed-off-by: Paarth Neekhara * bug fix Signed-off-by: Paarth Neekhara * Update nemo/collections/tts/models/magpietts.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Paarth Neekhara * Update nemo/collections/tts/data/text_to_speech_dataset.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Paarth Neekhara * Update nemo/collections/tts/models/magpietts.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Paarth Neekhara * typo fix Signed-off-by: Paarth Neekhara * bug fixes Signed-off-by: Paarth Neekhara * address comments Signed-off-by: Paarth Neekhara --------- Signed-off-by: Paarth Neekhara Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../text_to_speech/tts_tokenizers.py | 25 +++++--- .../tts/data/text_to_speech_dataset.py | 11 ++-- .../tts/data/text_to_speech_dataset_lhotse.py | 30 ++++----- nemo/collections/tts/models/magpietts.py | 64 +++++++++++++++---- scripts/magpietts/infer_and_evaluate.py | 13 ++-- ...L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh | 1 + ...TS_InferEvaluate_Magpietts_SeenSpeakers.sh | 1 + ...L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh | 1 + 8 files changed, 97 insertions(+), 49 deletions(-) diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 3d25289a2007..e6d0896beabe 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -1098,26 +1098,35 @@ def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase """ assert len(tokenizers) == len(tokenizer_names), "Number of tokenizers and tokenizer names must be the same." 
tokens = [] - toknizer_offsets = {} + tokenizer_offsets = {} tokenizer_offset = 0 self.tokenizers = {} + num_tokens_per_tokenizer = {} + tokenizer_pad_ids = {} for idx, tokenizer in enumerate(tokenizers): - self.tokenizers[tokenizer_names[idx]] = tokenizer - toknizer_offsets[tokenizer_names[idx]] = tokenizer_offset + tokenizer_name = tokenizer_names[idx] + self.tokenizers[tokenizer_name] = tokenizer + tokenizer_offsets[tokenizer_name] = tokenizer_offset if isinstance(tokenizer, BaseTokenizer): tokens.extend(tokenizer.tokens) num_tokens = len(tokenizer.tokens) + tokenizer_pad_ids[tokenizer_name] = tokenizer.pad + tokenizer_offset elif isinstance(tokenizer, PreTrainedTokenizerBase): _tokens = list(tokenizer.get_vocab().keys()) tokens.extend(_tokens) num_tokens = len(_tokens) + tokenizer_pad_ids[tokenizer_name] = tokenizer.pad_token_id + tokenizer_offset else: raise ValueError("Tokenizers must be either BaseTokenizer or HuggingFace PreTrainedTokenizerBase.") tokenizer_offset += num_tokens + num_tokens_per_tokenizer[tokenizer_name] = num_tokens self.tokens = tokens self.tokenizer_names = tokenizer_names - self.toknizer_offsets = toknizer_offsets + self.tokenizer_offsets = tokenizer_offsets + self.vocab_size = len(tokens) + self.num_tokens_per_tokenizer = num_tokens_per_tokenizer + self.tokenizer_pad_ids = tokenizer_pad_ids # Define aggregated token's pad value from the first tokenizer's pad value first_tokenizer = self.tokenizers[tokenizer_names[0]] if hasattr(first_tokenizer, "pad_token_id"): # Defined in PreTrainedTokenizerBase subclasses @@ -1127,12 +1136,12 @@ def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase else: raise ValueError("AggregatedTTSTokenizer could not find a padding token in the first tokenizer") - def encode(self, text: str, tokenizer_name: str) -> List[int]: + def encode(self, text: str, tokenizer_name: str = None) -> List[int]: tokenizer = self.tokenizers[tokenizer_name] tokens = tokenizer.encode(text) - return [self.toknizer_offsets[tokenizer_name] + token for token in tokens] + return [self.tokenizer_offsets[tokenizer_name] + token for token in tokens] - def decode(self, tokens: List[int], tokenizer_name: str) -> str: + def decode(self, tokens: List[int], tokenizer_name: str = None) -> str: tokenizer = self.tokenizers[tokenizer_name] - return tokenizer.decode([token - self.toknizer_offsets[tokenizer_name] for token in tokens]) + return tokenizer.decode([token - self.tokenizer_offsets[tokenizer_name] for token in tokens]) diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index ad2f9297e470..4d680a56912d 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -379,6 +379,7 @@ def __init__( tokenizer_config=None, load_16khz_audio: bool = True, use_text_conditioning_tokenizer: bool = False, + text_conditioning_tokenizer_name: str = None, pad_context_text_to_max_duration: bool = False, context_duration_min: float = 3.0, context_duration_max: float = 10.0, @@ -412,9 +413,7 @@ def __init__( self.text_tokenizer = None # Assigned in worker_init_fn in model file self.load_16khz_audio = load_16khz_audio self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer - self.text_conditioning_tokenizer = ( - None # Assigned in worker_init_fn in model file if use_text_conditioning_tokenizer is True - ) + self.text_conditioning_tokenizer_name = text_conditioning_tokenizer_name 
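The offset bookkeeping being renamed and extended above (tokenizer_offsets, per-tokenizer pad ids) boils down to shifting each sub-tokenizer's ids into one shared id space. A toy illustration with stand-in tokenizers rather than the real classes:

```
class ToyTokenizer:
    """Stand-in for a real tokenizer; maps each known character to an id."""

    def __init__(self, tokens):
        self.tokens = tokens

    def encode(self, text):
        return [self.tokens.index(ch) for ch in text if ch in self.tokens]


tokenizers = {"english_phoneme": ToyTokenizer(list("abc")), "chartokenizer": ToyTokenizer(list("xyz"))}

# Assign each tokenizer a starting offset into the shared id space.
offsets, running = {}, 0
for name, tok in tokenizers.items():
    offsets[name] = running
    running += len(tok.tokens)

def encode(text, tokenizer_name):
    return [offsets[tokenizer_name] + t for t in tokenizers[tokenizer_name].encode(text)]

print(encode("ab", "english_phoneme"))  # [0, 1]
print(encode("xy", "chartokenizer"))    # [3, 4]
```

Decoding reverses the shift by subtracting the same offset before delegating to the sub-tokenizer, which is exactly what the aggregated decode() above does.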
self.pad_context_text_to_max_duration = pad_context_text_to_max_duration self.context_duration_min = context_duration_min self.context_duration_max = context_duration_max @@ -579,17 +578,17 @@ def __getitem__(self, index): if self.use_text_conditioning_tokenizer: if 'context_text' in data.manifest_entry: - context_tokens = self.text_conditioning_tokenizer(data.manifest_entry['context_text'])['input_ids'] + context_tokens = self.text_tokenizer.encode(data.manifest_entry['context_text'], self.text_conditioning_tokenizer_name) example['has_text_context'] = True else: - context_tokens = self.text_conditioning_tokenizer("[NO TEXT CONTEXT]")['input_ids'] + context_tokens = self.text_tokenizer.encode("[NO TEXT CONTEXT]", self.text_conditioning_tokenizer_name) example['has_text_context'] = False if self.pad_context_text_to_max_duration: _required_len = ( int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 ) # +2 for BOS and EOS if len(context_tokens) < _required_len: - _pad_id = self.text_conditioning_tokenizer.pad_token_id + _pad_id = self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] context_tokens += [_pad_id] * (_required_len - len(context_tokens)) else: context_tokens = context_tokens[:_required_len] diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index b43f8ed032e7..31e8827165c1 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -33,7 +33,7 @@ from nemo.utils import logging -def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mode='train'): +def setup_tokenizers(all_tokenizers_config, mode='train'): # Being used in both model and worker_init_fn, so it is defined here # Returns two tokenizers: one for TTS transcript and one for conditioning text (if needed) tokenizers = [] @@ -42,6 +42,8 @@ def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mod tokenizer_config = all_tokenizers_config[tokenizer_name] if tokenizer_config._target_ == 'AutoTokenizer': tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.pretrained_model) + elif tokenizer_config._target_ == 'T5Tokenizer': + tokenizer = T5Tokenizer.from_pretrained(tokenizer_config.pretrained_model) else: text_tokenizer_kwargs = {} if "g2p" in tokenizer_config: @@ -52,16 +54,10 @@ def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mod tokenizer.set_phone_prob(1.0) tokenizers.append(tokenizer) tokenizer_names.append(tokenizer_name) - + aggregated_tokenizer = AggregatedTTSTokenizer(tokenizers, tokenizer_names) # TTS Transcript tokenizer - text_conditioning_tokenizer = None - - if use_text_conditioning_tokenizer: - # TODO: make this configurable - # Conditioning text tokenizer - text_conditioning_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") - - return aggregated_tokenizer, text_conditioning_tokenizer + + return aggregated_tokenizer def check_speaker_format(item: str): @@ -149,6 +145,7 @@ def __init__( context_duration_min: float = 3.0, context_duration_max: float = 10.0, use_text_conditioning_tokenizer: bool = False, + text_conditioning_tokenizer_name: str = None, tokenizer_config: DictConfig = None, ): super().__init__() @@ -168,12 +165,12 @@ def __init__( self.dataset_type = dataset_type # 'train' or 'test' self.load_16khz_audio = load_16khz_audio self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer 
+ self.text_conditioning_tokenizer_name = text_conditioning_tokenizer_name self.pad_context_text_to_max_duration = pad_context_text_to_max_duration self.context_duration_min = context_duration_min self.context_duration_max = context_duration_max self.tokenizer_config = tokenizer_config self.text_tokenizer = None - self.text_conditioning_tokenizer = None def get_num_audio_samples_to_slice(self, duration, sample_rate): num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) @@ -192,9 +189,8 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: worker_info = torch.utils.data.get_worker_info() worker_id = worker_info.id if worker_info is not None else 0 logging.info(f"Worker {worker_id} initializing tokenizers...") - self.text_tokenizer, self.text_conditioning_tokenizer = setup_tokenizers( + self.text_tokenizer = setup_tokenizers( all_tokenizers_config=self.tokenizer_config, - use_text_conditioning_tokenizer=self.use_text_conditioning_tokenizer, mode=self.dataset_type, ) self.bos_id = len(self.text_tokenizer.tokens) @@ -368,17 +364,17 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: if self.use_text_conditioning_tokenizer: if cut.supervisions[0].has_custom("context_text"): - context_text_tokens = self.text_conditioning_tokenizer(cut.supervisions[0].context_text)['input_ids'] + context_text_tokens = self.text_tokenizer.encode(cut.supervisions[0].context_text, tokenizer_name=self.text_conditioning_tokenizer_name) has_text_context = True else: - context_text_tokens = self.text_conditioning_tokenizer("[NO TEXT CONTEXT]")['input_ids'] + context_text_tokens = self.text_tokenizer.encode("[NO TEXT CONTEXT]", tokenizer_name=self.text_conditioning_tokenizer_name) has_text_context = False if self.pad_context_text_to_max_duration: _required_len = ( int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 ) # +2 for BOS and EOS if len(context_text_tokens) < _required_len: - _pad_id = self.text_conditioning_tokenizer.pad_token_id + _pad_id = self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] context_text_tokens += [_pad_id] * (_required_len - len(context_text_tokens)) else: # TODO @xueyang: It seems counter intuition if trimming the text context tokens to the required @@ -454,7 +450,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: if self.use_text_conditioning_tokenizer: batch_dict['context_text_tokens'] = collate_vectors( - tensors=context_text_tokens_list, padding_value=self.text_conditioning_tokenizer.pad_token_id + tensors=context_text_tokens_list, padding_value=self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] ) batch_dict['context_text_tokens_lens'] = torch.IntTensor(context_text_tokens_len_list) batch_dict['has_text_context'] = torch.BoolTensor(context_has_text_context_list) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index b96225daa025..a692abad3586 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -48,11 +48,11 @@ def worker_init_fn(worker_id): logging.info(f"Worker {worker_id} initializing...") worker_info = get_worker_info() dataset = worker_info.dataset # Get the dataset instance in this worker - tokenizer, text_conditioning_tokenizer = setup_tokenizers( - dataset.tokenizer_config, dataset.use_text_conditioning_tokenizer, mode=dataset.dataset_type + tokenizer = setup_tokenizers( + 
dataset.tokenizer_config, + mode=dataset.dataset_type ) dataset.text_tokenizer = tokenizer - dataset.text_conditioning_tokenizer = text_conditioning_tokenizer class MagpieTTSModel(ModelPT): """ @@ -112,16 +112,41 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): del cfg['text_tokenizer'] self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False) + # Using google-t5/t5-small as default text conditioning tokenizer for backward compatibility. + self.text_conditioning_tokenizer_name = cfg.get('text_conditioning_tokenizer_name', None) + self.legacy_text_conditioning = cfg.get('legacy_text_conditioning', False) + + if self.legacy_text_conditioning: + if self.text_conditioning_tokenizer_name is None: + self.text_conditioning_tokenizer_name = "google-t5/t5-small" + + tokenizer_target = "AutoTokenizer" + if self.text_conditioning_tokenizer_name == "google-t5/t5-small": + tokenizer_target = "T5Tokenizer" + + with open_dict(cfg): + cfg.text_tokenizers[self.text_conditioning_tokenizer_name] = { + '_target_' : tokenizer_target, + 'pretrained_model' : self.text_conditioning_tokenizer_name, + } + elif self.text_conditioning_tokenizer_name is None: + # If no text_conditioning_tokenizer_name is specified, use the first one as default + # For text context tokenization + self.text_conditioning_tokenizer_name = list(cfg.text_tokenizers.keys())[0] + # TODO @xueyang: both tokenizers are only used to get some token ids. We # should kill them to save a small amount of mem resources since dataloader will initialize them # again after the worker processes are spawned. - self.tokenizer, self.text_conditioning_tokenizer = setup_tokenizers( + self.tokenizer = setup_tokenizers( all_tokenizers_config=cfg.text_tokenizers, - use_text_conditioning_tokenizer=self.use_text_conditioning_encoder, mode='train', ) num_tokens_tokenizer = len(self.tokenizer.tokens) + if self.legacy_text_conditioning: + # Text context tokens are not a part of the the regular transcript embedding table in legacy models + num_tokens_tokenizer -= self.tokenizer.num_tokens_per_tokenizer[self.text_conditioning_tokenizer_name] + num_tokens = num_tokens_tokenizer + 2 # +2 for BOS and EOS self.bos_id = num_tokens - 2 self.eos_id = num_tokens - 1 @@ -133,8 +158,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg=cfg, trainer=trainer) - if self.use_text_conditioning_encoder: - self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) + if self.legacy_text_conditioning: + tc_tokenizer = self.tokenizer.tokenizers[self.text_conditioning_tokenizer_name] + self.context_text_embedding = nn.Embedding(tc_tokenizer.vocab_size, cfg.embedding_dim) # This needs to happen after super().__init__() self._codec_model = codec_model @@ -888,6 +914,15 @@ def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_co ) return alignment_loss + def embed_context_text(self, context_text_tokens): + if self.legacy_text_conditioning: + context_text_tokens = context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name] + context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E) + else: + context_text_embedded = self.text_embedding(context_text_tokens) # (B, L, E) + + return context_text_embedded + def prepare_context_tensors(self, batch): dec_context_size = 0 additional_decoder_input = None @@ -935,7 +970,8 @@ def prepare_context_tensors(self, batch): if 
self.use_text_conditioning_encoder: context_text_tokens = batch['context_text_tokens'] context_text_lens = batch['context_text_tokens_lens'] - context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E) + context_text_embedded = self.embed_context_text(context_text_tokens) # (B, L, E) + # Pad context_audio_embedded or context_text_embedded so that they have same number of timesteps if context_audio_embedded.size(1) < context_text_embedded.size(1): padding = torch.zeros( @@ -1506,8 +1542,8 @@ def construct_inference_prior(self, prior_epsilon, cross_attention_scores, for bidx in finished_texts_counter: finished_texts_counter[bidx] += 1 - if finished_texts_counter[bidx] > 10: - # This means we have been within the text EOS window for atleast 10 timesteps + if finished_texts_counter[bidx] > 5: + # This means we have been within the text EOS window for at least 5 timesteps # We should allow EOS to be predicted now. unfinished_texts[bidx] = False @@ -1871,6 +1907,7 @@ def get_dataset(self, dataset_cfg, dataset_type): load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=dataset_type, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, context_duration_min=self.cfg.context_duration_min, context_duration_max=self.cfg.context_duration_max, @@ -1901,6 +1938,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D context_duration_min=self.cfg.context_duration_min, context_duration_max=self.cfg.context_duration_max, use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, tokenizer_config=self.cfg.text_tokenizers, ) data_loader = get_lhotse_dataloader_from_config( @@ -1931,9 +1969,8 @@ def setup_training_data(self, dataset_cfg): if dataset_cfg.dataloader_params.num_workers == 0: persistent_workers = False # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) - dataset.text_tokenizer, dataset.text_conditioning_tokenizer = setup_tokenizers( + dataset.text_tokenizer = setup_tokenizers( all_tokenizers_config=self.cfg.text_tokenizers, - use_text_conditioning_tokenizer=self.use_text_conditioning_encoder, mode='train', ) self._train_dl = torch.utils.data.DataLoader( @@ -1960,9 +1997,8 @@ def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: if dataset_cfg.dataloader_params.num_workers == 0: persistent_workers = False # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) - dataset.text_tokenizer, dataset.text_conditioning_tokenizer = setup_tokenizers( + dataset.text_tokenizer = setup_tokenizers( all_tokenizers_config=self.cfg.text_tokenizers, - use_text_conditioning_tokenizer=self.use_text_conditioning_encoder, mode='test' ) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index d6429d4eafff..3a492ac520a0 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -53,7 +53,7 @@ def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0 metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) return metrics -def update_config(model_cfg, codecmodel_path, 
legacy_codebooks=False): +def update_config(model_cfg, codecmodel_path, legacy_codebooks=False, legacy_text_conditioning=False): ''' helper function to rename older yamls from t5 to magpie ''' model_cfg.codecmodel_path = codecmodel_path if hasattr(model_cfg, 'text_tokenizer'): @@ -63,6 +63,7 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False): model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0 model_cfg.train_ds = None model_cfg.validation_ds = None + model_cfg.legacy_text_conditioning = legacy_text_conditioning if "t5_encoder" in model_cfg: model_cfg.encoder = model_cfg.t5_encoder del model_cfg.t5_encoder @@ -273,6 +274,7 @@ def run_inference( use_local_transformer=False, maskgit_n_steps=3, legacy_codebooks=False, + legacy_text_conditioning=False, clean_up_disk=False, hparams_file_from_wandb=False, log_exp_name=False, @@ -289,7 +291,7 @@ def run_inference( model_cfg = model_cfg.value with open_dict(model_cfg): - model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks) + model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning) model = MagpieTTSModel(cfg=model_cfg) model.use_kv_cache_for_inference = True @@ -303,7 +305,7 @@ def run_inference( elif nemo_file is not None: model_cfg = MagpieTTSModel.restore_from(nemo_file, return_config=True) with open_dict(model_cfg): - model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks) + model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning) model = MagpieTTSModel.restore_from(nemo_file, override_config_path=model_cfg) model.use_kv_cache_for_inference = True checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] @@ -402,6 +404,7 @@ def run_inference( tokenizer_config=None, load_16khz_audio=model.model_type == 'single_encoder_sv_tts', use_text_conditioning_tokenizer=model.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=model.text_conditioning_tokenizer_name, pad_context_text_to_max_duration=model.pad_context_text_to_max_duration, context_duration_min=context_duration_min, context_duration_max=context_duration_max, @@ -417,7 +420,7 @@ def run_inference( g2p = model.tokenizer.g2p if g2p is not None: g2p.phoneme_probability = 1.0 - test_dataset.text_conditioning_tokenizer = model.text_conditioning_tokenizer + test_data_loader = torch.utils.data.DataLoader( test_dataset, @@ -603,6 +606,7 @@ def main(): parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) parser.add_argument('--legacy_codebooks', action='store_true') + parser.add_argument('--legacy_text_conditioning', action='store_true') parser.add_argument('--clean_up_disk', action='store_true') parser.add_argument('--cer_target', type=float, default=None) parser.add_argument('--ssim_target', type=float, default=None) @@ -647,6 +651,7 @@ def main(): use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, legacy_codebooks=args.legacy_codebooks, + legacy_text_conditioning=args.legacy_text_conditioning, clean_up_disk=args.clean_up_disk, hparams_file_from_wandb=args.hparams_file_from_wandb, log_exp_name=args.log_exp_name, diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh index a1d3ecd4b405..622d8b978df5 100644 --- a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh +++ 
b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh @@ -65,6 +65,7 @@ coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/ne model.context_encoder.n_layers=6 \ model.encoder.is_causal=false \ model.use_text_conditioning_encoder=true \ + +model.legacy_text_conditioning=True \ +model.forced_num_all_tokens_per_codebook=2048 \ +model.forced_audio_eos_id=2047 \ +model.forced_audio_bos_id=2046 \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh index 6a91252decfd..d47e112ad2bb 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh @@ -23,6 +23,7 @@ coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/ne --hparams_files /home/TestData/tts/2506_SeenSpeaker/hparams.yaml \ --checkpoint_files /home/TestData/tts/2506_SeenSpeaker/T5TTS--val_loss=0.3125-epoch=8.ckpt \ --legacy_codebooks \ + --legacy_text_conditioning \ --apply_attention_prior \ --clean_up_disk \ --cer_target 0.3 \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh index 098b04e7a637..228f7e97fc35 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh @@ -23,6 +23,7 @@ coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/ne --hparams_files /home/TestData/tts/2506_ZeroShot/lrhm_short_yt_prioralways_alignement_0.002_priorscale_0.1.yaml \ --checkpoint_files /home/TestData/tts/2506_ZeroShot/dpo-T5TTS--val_loss=0.4513-epoch=3.ckpt \ --legacy_codebooks \ + --legacy_text_conditioning \ --apply_attention_prior \ --clean_up_disk \ --cer_target 0.1 \ From 6fa892157d4bce83212f5d8c3a0d0bc55b8f0b84 Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Thu, 21 Aug 2025 15:23:17 -0400 Subject: [PATCH 076/113] Frame Stacking (#14455) * Frame stacking implementation + MaskGit refinements Squashed commit of the many commits from the following branch: `magpietts_2503_maskgit_dev` Signed-off-by: Fejgin, Roy * Evalset config tweak Signed-off-by: Fejgin, Roy * Docstring Signed-off-by: Fejgin, Roy * Cleanup Signed-off-by: Fejgin, Roy * Docs and varaible names - Add README for frame-stacking - Rename some variables for clarity Signed-off-by: Fejgin, Roy * README Signed-off-by: Fejgin, Roy * Cleanup and comments Signed-off-by: Fejgin, Roy * Cleanup, comments, docs Signed-off-by: Fejgin, Roy * README update Signed-off-by: Fejgin, Roy * Fix end-of-sequence handling during training with frame stacking Signed-off-by: Fejgin, Roy * Address PR comments * Disallow sampling of reserved tokens (MaskGit only for now). * Move code handling list of forbidden tokens to inside the SpecialAudioTokens class and refactor it a bit. * Use `self.num_audio_codebooks` rather than rely on tensor shape. Signed-off-by: Fejgin, Roy * Forbidden tokens - cleanup / refactor Signed-off-by: Fejgin, Roy * Fix path in config_evalset.py Fixed to the actual name of the file on EOS Signed-off-by: Fejgin, Roy * Add an option to seed the random number generators for debugging / reproducibility. Signed-off-by: Fejgin, Roy * Refine EOS detection When frame stacking is active we need to not only detect which decoder step included the first EOS but also detect which frame within the stack frame included the EOS. 
We then terminate the sequence at exactly the right position. Signed-off-by: Fejgin, Roy * Fix a print statement Signed-off-by: Fejgin, Roy * Fix checkpoint name string Signed-off-by: Fejgin, Roy * Cleanup and configuration validity checking ... for frame-stacking. Signed-off-by: Fejgin, Roy * Avoid speaker embedding crash on short signals The alternate speaker embedding model was not using the extract_embedding() method, which takes care of padding very short signals (which can sometimes be generated during evaluation) so that they don't crash the embedding model. This caused crashes during evaluation. Fixed it to use extract_embedding() like the default speaker embedding model does. Signed-off-by: Fejgin, Roy * Comments Signed-off-by: Fejgin, Roy * Inference: Fix for speaker embedding of short files The padded signal wasn't actually being given to the embedding model. Needed to write it to a temporary file first. Signed-off-by: Fejgin, Roy * Address Copilot's review comments Signed-off-by: Fejgin, Roy * Cleanup (per Copilot's suggestions) Signed-off-by: Fejgin, Roy * Address PR comments * Replace print statements with logging.debug * Add comments * Use `self.frame_stacking_factor` instead of a function argument Signed-off-by: Fejgin, Roy * Turn off MaskGit debug logging Also, reset the verbosity level to the original value after the debug logging. Signed-off-by: Fejgin, Roy * Control debug logging from env variable Signed-off-by: Fejgin, Roy * Clean up debug logging code Signed-off-by: Fejgin, Roy * Remove and/or disable debug code Signed-off-by: Fejgin, Roy * Remove more debug code Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- examples/tts/README_frame_stacking.md | 26 + nemo/collections/tts/models/magpietts.py | 543 ++++++++++++------ .../tts/modules/magpietts_modules.py | 26 +- nemo/collections/tts/parts/utils/helpers.py | 4 + scripts/magpietts/infer_and_evaluate.py | 49 +- 5 files changed, 466 insertions(+), 182 deletions(-) create mode 100644 examples/tts/README_frame_stacking.md diff --git a/examples/tts/README_frame_stacking.md b/examples/tts/README_frame_stacking.md new file mode 100644 index 000000000000..3be970c0446a --- /dev/null +++ b/examples/tts/README_frame_stacking.md @@ -0,0 +1,26 @@ +# Overview +This PR introduces frame-stacking implementation in Magpie-TTS. Frame-stacking is disabled by default. It can be enabled by setting a `frame_stacking_factor` > 1 in the YAML config. + +# Frame-stacking + +## Overview +Frame-stacking is a technique that allows the Magpie-TTS **base decoder** (also known as the "main" for "first stage" decoder) to **process multiple consecutive audio frames in a single forward pass**, leaving the job of generating individual frames and codebooks to a second, smaller, "Local Transformer" ("LT") decoder. The goal is to accelerate inference by reducing the number of generation steps of the base decoder. In this two-stage approach: + +1. The base decoder processes multiple frames at once, producing a single latent representation for each group (stack) of frames +2. The Local Transformer then generates the individual `frames * codebooks` tokens. + +The Local Transformer is much faster than the base decoder, making this two-stage approach significantly faster than generating each frame with the base decoder. 
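For intuition, here is a minimal PyTorch sketch of the grouping described above; the sizes, token range, and variable names are made up for illustration, and this is not the NeMo implementation itself:

import torch

B, C, T, factor = 2, 8, 12, 3                 # batch, codebooks, frames, frame_stacking_factor
codes = torch.randint(0, 2016, (B, C, T))     # codec tokens for T frames

# The base decoder advances once per stack, i.e. T // factor steps instead of T.
stacks = codes.reshape(B, C, T // factor, factor)

# The Local Transformer expands each base-decoder latent back into factor * C tokens,
# i.e. it undoes the grouping above.
recovered = stacks.reshape(B, C, T)
assert torch.equal(recovered, codes)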
The speed improvement comes from two factors: +* **Fewer parameters**: The LT decoder is lightweight compared to the base decoder +* **Shorter sequences**: The LT decoder only attends to the current frame stack and the latent, not the entire frame sequence + +The base decoder can also generate audio codes directly without an LT, but when frame-stacking is enabled, using the LT decoder is typically necessary to achieve high-quality synthesis. + +## Design and Implementation +* The `frame_stacking_factor` is the parameter that controls the number of frames to stack. The default is 1, which means no frame-stacking. We have tested values up to `4`. +* For each codebook, we keep a separate embedding table for each frame within the stack. At the input to the decoder, the embeddings are averaged across codebooks (as usual) and also across frames within the stack. The embedding tables are shared between the base and LT decoders. + +## Limitations +This is still WIP with more work to be done. Specifically, the following are not yet implemented / tested: +* Online code extraction combined with frame-stacking. +* Alignment encoder with frame-stacking. +* CTC loss with frame-stacking. \ No newline at end of file diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index a692abad3586..4c4e59141c69 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -14,7 +14,8 @@ import os import random import time -from typing import List +from typing import List, Optional +from functools import partial import soundfile as sf import torch import wandb @@ -24,6 +25,7 @@ from omegaconf import DictConfig, OmegaConf, open_dict from torch import nn from torch.utils.data import get_worker_info +import numpy as np import nemo.collections.asr as nemo_asr from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config @@ -32,7 +34,12 @@ from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.modules.aligner import AlignmentEncoder -from nemo.collections.tts.modules.magpietts_modules import CharAwareSubwordEncoder, SpecialAudioToken, LocalTransformerType, cosine_schedule +from nemo.collections.tts.modules.magpietts_modules import ( + CharAwareSubwordEncoder, + SpecialAudioToken, + LocalTransformerType, + cosine_schedule, +) from nemo.collections.tts.parts.utils.helpers import ( binarize_attention_parallel, get_mask_from_lengths, @@ -96,14 +103,21 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code.
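# Sketch of the resulting id layout (the codebook size of 2016 and the enum ordering below are
# assumptions for illustration; the real values come from the codec model and SpecialAudioToken):
base = 2016                                   # hypothetical codec codebook_size
special = ["AUDIO_BOS", "AUDIO_EOS", "AUDIO_CONTEXT_BOS", "AUDIO_CONTEXT_EOS", "MASK_TOKEN"]
ids = {name: base + offset for offset, name in enumerate(special)}
print(ids)                                    # {'AUDIO_BOS': 2016, ..., 'MASK_TOKEN': 2020}
print(base + len(special))                    # 2021 -> num_all_tokens_per_codebook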
num_audio_tokens = codec_model.codebook_size - self.audio_bos_id = cfg.get('forced_audio_bos_id', num_audio_tokens + SpecialAudioToken.AUDIO_BOS.value) - self.audio_eos_id = cfg.get('forced_audio_eos_id', num_audio_tokens + SpecialAudioToken.AUDIO_EOS.value) - self.context_audio_bos_id = cfg.get('forced_context_audio_bos_id', num_audio_tokens + SpecialAudioToken.AUDIO_CONTEXT_BOS.value) - self.context_audio_eos_id = cfg.get('forced_context_audio_eos_id', num_audio_tokens + SpecialAudioToken.AUDIO_CONTEXT_EOS.value) - self.num_all_tokens_per_codebook = cfg.get('forced_num_all_tokens_per_codebook',num_audio_tokens + len(SpecialAudioToken)) - self.mask_token_id = cfg.get('forced_mask_token_id', num_audio_tokens + SpecialAudioToken.MASK_TOKEN.value) + get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=num_audio_tokens) + self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS)) + self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS)) + self.context_audio_bos_id = cfg.get('forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS)) + self.context_audio_eos_id = cfg.get('forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS)) + self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN)) + self.num_all_tokens_per_codebook = cfg.get('forced_num_all_tokens_per_codebook', num_audio_tokens + len(SpecialAudioToken)) self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) + # The frame stacking factor controls how many consecutive frames are processed together by the base decoder + # (and then refined into individual frames by the local transformer). A frame stacking factor of 1 means no + # frame stacking. We have a separate embedding table for each of the stacked frames, e.g. for frame stacking + # factor of 3, the entries of codebook 0 appear 3 times in the embedding table. 
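# Standalone sketch of the per-(codebook, position-in-stack) embedding tables and the averaging
# described in the comment above; sizes and names here are made up, and in the model the tables
# live in self.audio_embeddings:
import torch
from torch import nn

C, factor, vocab, dim = 8, 3, 2021, 16
tables = nn.ModuleList([nn.Embedding(vocab, dim) for _ in range(C * factor)])
codes = torch.randint(0, vocab, (2, C, 6))         # (B, C, T) with T a multiple of factor
stacked = sum(
    tables[c + i * C](codes[:, c, i::factor])      # (B, T // factor, dim)
    for i in range(factor)
    for c in range(C)
) / (C * factor)
assert stacked.shape == (2, 2, dim)                # one averaged embedding per stack of frames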
+ self.frame_stacking_factor = cfg.get('frame_stacking_factor', 1) + assert 'downsample_factor' not in cfg, '`downsample_factor` is deprecated, use `frame_stacking_factor` instead' # Setup tokenizer if hasattr(cfg, 'text_tokenizer'): # For backward compatibility for English-only models @@ -167,7 +181,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self._codec_model.freeze() #Lightning does requires_grad = False and self.eval() audio_embeddings = [] - for _ in range(self.num_audio_codebooks): + for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim)) self.audio_embeddings = nn.ModuleList(audio_embeddings) @@ -199,8 +213,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) - - self.final_proj = nn.Linear(cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook) + self.final_proj = nn.Linear(cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor) self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower()) logging.info(f"Local transformer type: {self.local_transformer_type}") @@ -217,11 +230,11 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): sa_n_heads=self.cfg.get('local_transformer_n_heads', 1), kernel_size=1, is_causal=self.local_transformer_type == LocalTransformerType.AR, - max_length_causal_mask=self.num_audio_codebooks+2, + max_length_causal_mask=self.frame_stacking_factor*self.num_audio_codebooks+2, use_learnable_pos_emb=True, ) local_transformer_out_projections = [] - for _ in range(self.num_audio_codebooks): + for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): # Have a separate projection layer for each codebook, to distinguish between them local_transformer_out_projections.append(nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook)) self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections) @@ -274,7 +287,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): elif self.model_type == 'decoder_pretrain_synthesizer': # This is for pretraining the decoder only on audio data using next frame prediction loss - assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" + assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" else: raise ValueError(f"Unsupported model type {self.model_type}") @@ -308,6 +321,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.aligner_encoder_train_steps = self.cfg.get('aligner_encoder_train_steps', float('inf')) self.dec_random_input_max = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook) + # Configuration validity checks + self.check_frame_stacking_config_validity() + def state_dict(self, destination=None, prefix='', keep_vars=False): """ @@ -324,6 +340,27 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): del state_dict[key] return state_dict + def check_frame_stacking_config_validity(self): + """ + Check if the configuration is compatible with frame stacking. + """ + if self.frame_stacking_factor > 1: + # The settings below are not supported with frame stacking. 
+ # Some of them may work - but they have not been tested. + + # disallow alignment encoder + if self.use_alignment_encoder: + raise ValueError("Alignment encoder is not supported for frame stacking") + # disallow alignment loss + if self.alignment_loss_scale > 0.0: + raise ValueError("Alignment loss is not supported for frame stacking") + # disallow training prior + if self.cfg.prior_scaling_factor is not None and self.cfg.prior_scaling_factor > 0: + raise ValueError("Training-time attention prior is not supported for frame stacking") + # disallow text conditioning + if self.use_text_conditioning_encoder: + raise ValueError("Text conditioning is not supported for frame stacking") + def update_ckpt(self, state_dict): """ Backward compatibility for checkpoints saved with old model names. @@ -386,6 +423,8 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): bos_tensor = torch.full( (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device ) + # pad at the end to make room for the EOS token; the EOS token's actual position + # varies per batch element depending on each element's length. pad_tensor = torch.full( (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device ) # 0 is the padding token in the audio codebook @@ -394,8 +433,7 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): # codes_len: (B,) for idx in range(codes.size(0)): codes[idx, :, codes_len[idx] + 1] = audio_eos_id - codes_len = codes_len + 2 - + codes_len = codes_len + 2 # +1 for bos and +1 for eos return codes.long(), codes_len.long() def codes_to_audio(self, codes, codes_len): @@ -415,16 +453,17 @@ def codes_to_audio(self, codes, codes_len): return audio, audio_len def embed_audio_tokens(self, audio_tokens): - # audio_tokens: (B, C, T') - # Add and average the embeddings of the audio tokens across the codebooks + B, C, T = audio_tokens.shape audio_embedding = None - for c in range(audio_tokens.size(1)): - embedding = self.audio_embeddings[c](audio_tokens[:, c, :]) - if audio_embedding is None: - audio_embedding = embedding - else: - audio_embedding = audio_embedding + embedding - audio_embedding = audio_embedding / audio_tokens.size(1) + for i in range(self.frame_stacking_factor): + for c in range(C): + tokens = audio_tokens[:,c , i::self.frame_stacking_factor] + embedding = self.audio_embeddings[c + i * C](tokens) + if audio_embedding is None: + audio_embedding = embedding + else: + audio_embedding += embedding + audio_embedding = audio_embedding / (C * self.frame_stacking_factor) return audio_embedding def get_speaker_embeddings(self, audio_16khz, audio_len_16khz): @@ -445,31 +484,39 @@ def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_ (using an 8-codebook setup as an example): +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | AR target | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | none | + | codebook | | | | | | | | | | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | MG target | none | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + | codebook | | | | | | | | | | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ - | Input | Magpie | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - | | Latent | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | + | input | Magpie | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + | codebook | latent | or MASK | or MASK | or MASK | or 
MASK | or MASK | or MASK | or MASK | or MASK | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ - | Seq. Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | + | seq. index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ - + dec_out: (B, T', E) audio_codes_target: (B, C, T') targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive) if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit) """ + C = self.num_audio_codebooks dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E) local_transformer_input = [dec_out_all] - for codebook_num in range(audio_codes_target.size(1)): - codes = audio_codes_target[:, codebook_num] # (B, T') - codes = codes.reshape(-1) # (B*T',) - codebook_embedding = self.audio_embeddings[codebook_num](codes) # (B*T', E) - local_transformer_input.append(codebook_embedding) - + # Build the teacher-forced input to the LT. + for fs_index in range(self.frame_stacking_factor): + for codebook_num in range(C): + # Collect ground truth codes for the current codebook and frame stack index combintation. + codes = audio_codes_target[:, codebook_num, fs_index::self.frame_stacking_factor] # (B, T') + # Individual timesteps are independently handled by the LT fold time into the batch dimension. + codes = codes.reshape(-1) # (B*T',) + # Embed the codes + codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes) # (B*T', E) + local_transformer_input.append(codebook_embedding) + # Stack the input codes along dimension 1 (codebooks). This is the dimension along which the LT predicts iteratively. local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E) local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128) - _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) + _mask = torch.ones(local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E) if not targets_offset_by_one: # for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc. @@ -478,16 +525,17 @@ def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_ # for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc. 
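# For example, with C=8 and frame_stacking_factor=1 the LT input has 9 positions (the Magpie latent
# plus 8 teacher-forced codes): in the autoregressive case positions 0..7 serve as predictors for
# codebooks 0..7, while for MaskGit position 0 (the latent) has no target and positions 1..8 predict
# codebooks 0..7, as the docstring table shows.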
local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E) all_code_logits = [] - for codebook_num in range(audio_codes_target.size(1)): - # Using a separate projection layer for each codebook (to distinguish between them) - # Checked the time - this loop is not taking much time (compared to the local transformer forward pass) - codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, codebook_num, :]) # (B*T', num_all_tokens_per_codebook) - all_code_logits.append(codebook_logits) - all_code_logits = torch.cat(all_code_logits, dim=1) # (B*T', num_codebooks * num_all_tokens_per_codebook) + for fs_index in range(self.frame_stacking_factor): + for codebook_num in range(audio_codes_target.size(1)): + # Using a separate projection layer for each codebook (to distinguish between them) + # Checked the time - this loop is not taking much time (compared to the local transformer forward pass) + codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index*C](local_transformer_output[:, codebook_num + fs_index*C, :]) # (B*T', num_all_tokens_per_codebook) + all_code_logits.append(codebook_logits) + all_code_logits = torch.cat(all_code_logits, dim=1) # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor) all_code_logits = all_code_logits.view( - audio_codes_target.size(0), audio_codes_target.size(2), -1 - ) # (B, T', C * num_all_tokens_per_codebook) + audio_codes_target.size(0), audio_codes_target.size(2) // self.frame_stacking_factor, -1 + ) # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor) return all_code_logits @@ -532,7 +580,7 @@ def maskgit_apply_random_mask(self, codes): codes_with_mask = torch.where(mask, self.mask_token_id, codes) return codes_with_mask, mask - def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None): + def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None, frame_stacking_factor=1): """ Computes the audio codebook loss. Used by (1) The main Magpie-TTS transformer @@ -542,9 +590,10 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=N audio_codes: (B, C, T') audio_codes_lens: (B,) mask_tokens_mask: (B, C, T') True for tokens that were replaced with the MASK_TOKEN and should - therefore be the only ones included in the loss computation. + therefore be the only ones included in the loss computation (for MaskGit). + frame_stacking_factor: int, the stacking factor used in the model """ - loss_mask = get_mask_from_lengths(audio_codes_lens) + loss_mask = get_mask_from_lengths(audio_codes_lens, pad_to_factor=frame_stacking_factor) if mask_tokens_mask is not None: # For MaskGit we only compute loss for the masked tokens. 
# *Both* conditions must be true: @@ -559,22 +608,27 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=N # repeat loss mask for each codebook to simplify code below loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1) total_codebook_loss = None - for codebook in range(audio_codes.size(1)): - si = codebook * self.num_all_tokens_per_codebook - ei = si + self.num_all_tokens_per_codebook - codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) - codebook_targets = audio_codes[:, codebook] # (B, T') - codebook_loss = self.cross_entropy_loss( - codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T') - ) # (B, T') - codebook_loss = codebook_loss * loss_mask[:, codebook, :] - codebook_loss = codebook_loss.sum() / loss_mask[:, codebook, :].sum() - if total_codebook_loss is None: - total_codebook_loss = codebook_loss - else: - total_codebook_loss = total_codebook_loss + codebook_loss + for fs_index in range(frame_stacking_factor): + for codebook in range(audio_codes.size(1)): + si = (codebook + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) + codebook_targets = audio_codes[:, codebook, fs_index::frame_stacking_factor] # (B, T') + codebook_loss = self.cross_entropy_loss( + codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T') + ) # (B, T') + codebook_loss_mask = loss_mask[:, codebook, fs_index::frame_stacking_factor] + codebook_loss = codebook_loss * codebook_loss_mask + if codebook_loss_mask.sum() == 0: + logging.warning(f"Loss mask for codebook {codebook} is all zeros, global_step: {self.global_step}") + continue + codebook_loss = codebook_loss.sum() / codebook_loss_mask.sum() + if total_codebook_loss is None: + total_codebook_loss = codebook_loss + else: + total_codebook_loss = total_codebook_loss + codebook_loss - total_codebook_loss = total_codebook_loss / audio_codes.size(1) + total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor) return total_codebook_loss, loss_mask def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prior, multi_encoder_mapping): @@ -593,81 +647,154 @@ def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prio def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): # all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook) # audio_codes_lens: (B,) - all_preds = [] - for idx in range(self.num_audio_codebooks): - si = idx * self.num_all_tokens_per_codebook - ei = si + self.num_all_tokens_per_codebook - codebook_logits = all_code_logits[:, :, si:ei] - codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook) - # argmax to get the tokens - codebook_preds = torch.argmax(codebook_probs, dim=-1) # (B, T') - all_preds.append(codebook_preds) - - all_preds = torch.stack(all_preds, dim=1) # (B, C, T') + all_preds = [[] for _ in range(self.frame_stacking_factor)] + for fs_index in range(self.frame_stacking_factor): + for idx in range(self.num_audio_codebooks): + si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + codebook_logits = all_code_logits[:, :, si:ei] + codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook) + # argmax to get the tokens + codebook_preds = 
torch.argmax(codebook_probs, dim=-1) # (B, T') + all_preds[fs_index].append(codebook_preds) + all_preds = [torch.stack(p, dim=1) for p in all_preds] # list of `frame_stacking_factor`` elements of shape (B,C,T) each + all_preds = torch.stack(all_preds, dim=-1) # B, C, T, frame_stacking_factor + # undo the frame stacking + all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1) # B, C, T*frame_stacking_factor + pred_max_len = all_preds.size(2) + real_max_len = audio_codes_lens.max() + assert (pred_max_len - real_max_len) < self.frame_stacking_factor + # trim padding introduced for frame stacking + all_preds = all_preds[:,:, :real_max_len] audio_mask = get_mask_from_lengths(audio_codes_lens) all_preds = all_preds * audio_mask.unsqueeze(1) return all_preds - def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0, n_steps=3): + def visualize_codes(self, codes, mask_id=2020, frame_stacking_rate=2): """ - Sample codes for one timestep from the local transformer using MaskGit. + Visualize codes for analysis purposes + codes: (B, C) + """ + def code_to_str(code): + if code==mask_id: + return "M " + else: + return f"{code:04d} " + B, C = codes.shape + if B > 1: + logging.debug("Warning: visualizing only first batch element") + codes = codes.clone().detach().cpu().numpy()[0] + codes = [code_to_str(c) for c in codes] + output_str = "" + for i, c in enumerate(codes): + if (i) % (C/frame_stacking_rate) == 0: + output_str += "|timestep| " + output_str += c + logging.debug(output_str) + + def clear_forbidden_logits(self, logits, clear_audio_eos=False): + """ + Sets logits for forbidden tokens to `-inf`. + Args: + logits: (B, C, num_audio_tokens_per_codebook) + clear_audio_eos: bool, whether to clear the audio EOS token + """ + logits[:,:, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=clear_audio_eos)] = float('-inf') + return logits + + def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0, n_steps=3, noise_scale=0.0, fixed_schedule=None, dynamic_cfg_scale=False, sampling_type=None): """ + Sample codes for one timestep from the local transformer using MaskGit. + """ # dec_output: (B, E) device = dec_output.device # disable KV cache since our transformer is not causal self.local_transformer.reset_cache(use_cache=False) dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input_init = self.local_transformer_in_projection(dec_output) # (B, 1, D) where D is the dimension of the local transformer - C = self.num_audio_codebooks + codebook_seq_len = self.num_audio_codebooks * self.frame_stacking_factor B = dec_output.size(0) - min_confidence = float("-inf") - max_confidence = 10000 # this needs to be large enough that unmasked items will always remain unmasked. # TODO @rfejgin: use float('inf')? - confidences = min_confidence * torch.ones(B, C, device=device) + min_confidence = 0 + # this needs to be large enough that unmasked items will always remain unmasked (even after noise addition) + # Setting it smaller could allow "regret", i.e. 
re-masking a codebook that was previously unmasked; we might want to try that + max_confidence = 5 + confidences = min_confidence * torch.ones(B, codebook_seq_len, device=device) # initialize to all masked - codes = self.mask_token_id * torch.ones((B, C), device=device, dtype=torch.long) + codes = self.mask_token_id * torch.ones((B, codebook_seq_len), device=device, dtype=torch.long) sampled_codes = codes.clone() + topk_indices = None + if fixed_schedule is not None: + n_steps = len(fixed_schedule) for step in range(n_steps): + # how far along we are in the unmasking process + progress = step / n_steps # get mask fraction - frac_masked = cosine_schedule(torch.tensor(step / (n_steps))) + frac_masked = cosine_schedule(torch.tensor(progress)) + if sampling_type == "causal" or sampling_type == "purity_causal": + frac_masked = torch.ones_like(frac_masked) * (1.0 - progress) # how many codebooks to mask - n_masked = torch.ceil(C * frac_masked).long() # TODO @rfejgin: should we force this to be initialized to exactly `C` (to avoid numerical issues)? - n_unmasked = C - n_masked + if fixed_schedule is None: + n_masked = torch.ceil(codebook_seq_len * frac_masked).long() + else: + n_masked = codebook_seq_len - fixed_schedule[step] + n_unmasked = codebook_seq_len - n_masked + + if sampling_type == "causal" or sampling_type == "purity_causal":# and n_unmasked <= self.num_audio_codebooks: + # force second frame not to be unmasked + n_frames_to_allow = int(np.floor(progress*self.frame_stacking_factor+1)) + confidences[:,n_frames_to_allow*self.num_audio_codebooks:] = min_confidence-1 # only tested for frame_stacking_factor=2 + # pick top-confidence codebooks up to n_unmasked _, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1) + if use_cfg: + actual_batch_size = topk_indices.size(0) // 2 + assert (topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size]).all(), f"Topk indices are not the same for conditional and unconditional codes" - # replace masks of the top-k confident codebooks with the the codes that were sampled for them + # replace masks of the top-k confident codebooks with the codes that were sampled for them unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices) codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) - + # build transformer input local_transformer_input = local_transformer_input_init - for codebook_num in range(C): + for codebook_num in range(codebook_seq_len): next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze(1) # (B, 1, 768) next_local_transformer_input = self.local_transformer_in_projection(next_local_transformer_input) # (B, 1, d_local) local_transformer_input = torch.cat([local_transformer_input, next_local_transformer_input], dim=1) # (B, codebook_num+1, d_local) # run transformer - _mask = torch.ones(B, C+1, device=device) + _mask = torch.ones(B, codebook_seq_len+1, device=device) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, C+1, d_local) # get logits logits = [] - for codebook_num in range(C): + for codebook_num in range(codebook_seq_len): # The `codebook_num+1` is to drop first position which corresponds to the magpie latent codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, codebook_num+1, :]) # (B, num_audio_tokens_per_codebook) logits.append(codebook_logits) - logits = torch.stack(logits, dim=1) # (B, C, num_audio_tokens_per_codebook) + logits = torch.stack(logits, dim=1) 
# (B, C*frame_stacking_factor, num_audio_tokens_per_codebook) # apply CFG if use_cfg: actual_batch_size = logits.size(0) // 2 conditional_logits = logits[:actual_batch_size] unconditional_logits = logits[actual_batch_size:] - cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits + if not dynamic_cfg_scale: + current_cfg_scale = cfg_scale + else: + # gradually increase the scale until mid point through sampling, then reduce it again + progress = step / (n_steps-1) + #interp = -abs(progress-0.5)+0.5 # increase from 0..1 in the interval from start to midpoint and then go back to zero + #interp = 1.0 - progress # decrease from 1 to 0 + interp = progress # gradually increase from 0 to 1 + current_cfg_scale = (cfg_scale - 1) * interp + 1.0 # 1.0 --> cfg_scale --> 1.0 + cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits logits[:actual_batch_size] = cfg_logits + # Disallow generation of special tokens (except audio EOS which is handled separately) + logits = self.clear_forbidden_logits(logits, clear_audio_eos=False) + # handle unfinished and finished items for item_idx in unfinished_items: logits[item_idx, self.audio_eos_id] = float('-inf') @@ -681,23 +808,37 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, logits_rescored = logits.clone() logits_rescored[indices_to_remove] = float('-inf') probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook) - sampled_codes = torch.multinomial(probs.view(B*C, -1), 1).view(B, C) + sampled_codes = torch.multinomial(probs.view(B*codebook_seq_len, -1), 1).view(B, codebook_seq_len) if use_cfg: - # TODO @rfejgin: why do we need to keep second half of the batch? can probably optimize this sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size] probs[actual_batch_size:] = probs[:actual_batch_size] - confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1) - - # set confidence to max for unmasked codebooks so that they will remain unmasked - confidences.scatter_(index=topk_indices, dim=1, src=max_confidence*torch.ones_like(topk_indices, dtype=torch.float)) - + if sampling_type != "purity_causal" and sampling_type != "purity_default": + confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1) + else: + # use the max probability across all tokens for each codebook as the confidence for each codebook; known as "purity sampling" + confidences = probs.max(dim=2)[0] # replace entries in sampled_codes with previously unmasked codebooks sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) - # optionally: add noise to confidences here (as in token-critic paper) (not implemented) - + # add noise to confidences (as in token-critic paper, https://arxiv.org/abs/2209.04439) + if noise_scale > 0.0: + # get noise from uniform distribution in the interval [-0.5, 0.5), scale it by `noise_scale`, + # and anneal it to 0 as we approach the end of the unmasking process + noise = (torch.rand_like(confidences) - 0.5) * noise_scale * (1-(step+2)/n_steps) # the +2 makes sure that by the last iteration the noise is exactly 0 + confidences += noise + # the conditional and unconditional get different noise and must be fixed to be the same again + confidences[actual_batch_size:] = confidences[:actual_batch_size] + confidence_eps = 0.1 + assert confidences.max() + confidence_eps < max_confidence, f"Predicted confidence is approaching 
max_confidence: {confidences.max()}" + # for unmasked codebooks, set confidence to max so that they will remain unmasked + confidences.scatter_(index=topk_indices, dim=1, src=max_confidence*torch.ones_like(topk_indices, dtype=torch.float)) codes = sampled_codes assert not (codes == self.mask_token_id).any(), f"Codes contain mask tokens after completion of MaskGit sampling" - if use_cfg: + + # break stacked groups of frames into individual frames + codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute(0,2,1) # B, C, frame_stacking_factor + + if use_cfg: + # drop unconditional codes codes = codes[:actual_batch_size] return codes @@ -707,7 +848,7 @@ def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, t dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) all_preds = [] - for codebook_num in range(self.num_audio_codebooks): + for codebook_num in range(self.num_audio_codebooks * self.frame_stacking_factor): _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128) codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, -1, :]) # (B, num_all_tokens_per_codebook) @@ -737,36 +878,40 @@ def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, t next_local_transformer_input = self.local_transformer_in_projection(next_local_transformer_input) # (B, 1, 128) local_transformer_input = torch.cat([local_transformer_input, next_local_transformer_input], dim=1) # (B, T+1, 128) - all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks) + all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks * frame_stacking_factor) + all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute(0,2,1) # (B, num_codebooks, frame_stacking_factor) if use_cfg: all_preds = all_preds[:actual_batch_size] return all_preds - def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={}): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep - all_preds = [] - for idx in range(self.num_audio_codebooks): - si = idx * self.num_all_tokens_per_codebook - ei = si + self.num_all_tokens_per_codebook - codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook) - for item_idx in unfinished_items: - codebook_logits[item_idx, self.audio_eos_id] = float('-inf') - for item_idx in finished_items: - codebook_logits[item_idx, :] = float('-inf') - codebook_logits[item_idx, self.audio_eos_id] = 0.0 - codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) - indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( - -1 - ) # (B, num_tokens_per_codebook) - codebook_logits_rescored = codebook_logits.clone() - codebook_logits_rescored[indices_to_remove] = float('-inf') - - codebook_probs = torch.softmax(codebook_logits_rescored / temperature, dim=-1) # (B, num_tokens_per_codebook) - codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) - all_preds.append(codebook_preds) - all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks) + all_preds = [[] for _ in range(self.frame_stacking_factor)] + for fs_index in range(self.frame_stacking_factor): + for 
idx in range(self.num_audio_codebooks): + si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook) + + for item_idx in unfinished_items: + codebook_logits[item_idx, self.audio_eos_id] = float('-inf') + for item_idx in finished_items: + codebook_logits[item_idx, :] = float('-inf') + codebook_logits[item_idx, self.audio_eos_id] = 0.0 + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) + indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( + -1 + ) # (B, num_tokens_per_codebook) + codebook_logits_rescored = codebook_logits.clone() + codebook_logits_rescored[indices_to_remove] = float('-inf') + + codebook_probs = torch.softmax(codebook_logits_rescored / temperature, dim=-1) # (B, num_tokens_per_codebook) + codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) + all_preds[fs_index].append(codebook_preds) + + all_preds = [torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds] # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks) + all_preds = torch.stack(all_preds, dim=2) # (B, num_codebooks, frame_stacking_factor) return all_preds def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0): @@ -914,6 +1059,23 @@ def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_co ) return alignment_loss + def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int =0): + """ + Pads the time dimension of the audio codes to a multiple of the frame stacking factor. + Args: + audio_codes (torch.Tensor): B, C, T + frame_stacking_factor (int): The factor that frames will be stacked by. + pad_token (int): The token ID to pad with. 
+ Returns: + B, C, T_padded + """ + T = audio_codes.size(2) + T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor) + if T_padded > T: + padding = pad_token * torch.ones(audio_codes.size(0), audio_codes.size(1), T_padded - T, device=audio_codes.device, dtype=audio_codes.dtype) + audio_codes = torch.cat([audio_codes, padding], dim=2) + return audio_codes + def embed_context_text(self, context_text_tokens): if self.legacy_text_conditioning: context_text_tokens = context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name] @@ -965,7 +1127,8 @@ def prepare_context_tensors(self, batch): context_audio_codes, context_audio_codes_lens = self.audio_to_codes( batch['context_audio'], batch['context_audio_lens'], audio_type='context' ) - context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T', E) + context_audio_codes = self.pad_audio_codes(context_audio_codes, self.frame_stacking_factor, pad_token=0) + context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T/frame_stacking_factor, E) if self.use_text_conditioning_encoder: context_text_tokens = batch['context_text_tokens'] @@ -1000,6 +1163,7 @@ def prepare_context_tensors(self, batch): else: context_input_embedded = context_audio_embedded context_input_lens = context_audio_codes_lens + context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to(context_input_lens.dtype) context_mask = get_mask_from_lengths(context_input_lens) @@ -1159,12 +1323,25 @@ def process_batch(self, batch, mode="train"): else: audio_codes = batch['audio_codes'] audio_codes_lens = batch['audio_codes_lens'] - - audio_codes_input = audio_codes[:, :, :-1] # B, C, T' - audio_codes_target = audio_codes[:, :, 1:] - audio_codes_lens_input = audio_codes_lens_target = audio_codes_lens - 1 + if self.frame_stacking_factor > 1: + # repeat the BOS token to frame_stacking_factor times. This is necessary since at inference + # we need to start autoregressive generation from a full stack indicating BOS. 
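# For example, with frame_stacking_factor=3 a target sequence [BOS, c1, c2, c3, ...] becomes
# [BOS, BOS, BOS, c1, c2, c3, ...], so the first stacked decoder input position is an all-BOS stack,
# matching the all-BOS audio_codes_bos stack that seeds autoregressive generation in infer_batch.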
+ # TODO: @rfejgin: this assert might be slow due to GPU/CPU sync + assert (audio_codes[:,:,0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token" + audio_codes = torch.cat([torch.full((audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1), self.audio_bos_id, device=audio_codes.device, dtype=audio_codes.dtype), audio_codes], dim=2) + audio_codes_lens += self.frame_stacking_factor - 1 # account for BOS repeat + audio_codes = self.pad_audio_codes(audio_codes, self.frame_stacking_factor, pad_token=0) + # Note: if a tensor lacks the `_unstacked` suffix, it can be assumed to to be in the frame-stacked domain + + # drop last (stacked) frame since it is not part of *input* + audio_codes_input_unstacked = audio_codes[:, :, :-self.frame_stacking_factor] # B, C, T' + # drop first (stacked) frame which contains BOS token(s) which are not part of *target* + audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor:] + audio_codes_lens_input_unstacked = audio_codes_lens - 1 # don't count EOS for input + audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor # don't count BOS for target + audio_codes_lens_input = torch.floor(audio_codes_lens_input_unstacked / self.frame_stacking_factor).long() audio_codes_embedded_all = self.embed_audio_tokens(audio_codes) # (B, T, E) # Computing this to be use in the alignment encoder - audio_codes_embedded = audio_codes_embedded_all[:, :-1, :] # (B, T', E) Input to the decoder + audio_codes_embedded = audio_codes_embedded_all[:, :-1, :] # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`) audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input) use_cfg = ( @@ -1200,16 +1377,16 @@ def process_batch(self, batch, mode="train"): # which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on # audio_codes_input so should not matter if we don't supply dec_random_input_max. 
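# In effect, roughly `decoder_input_dropout_prob` of the input timesteps (the same timesteps for every
# batch item, since the dropout mask below has shape (1, 1, T)) are replaced with random codes while the
# prediction targets stay untouched, acting as input dropout on the autoregressive decoder inputs.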
random_audio_tokens = torch.randint( - 0, max_codebook_val, audio_codes_input.size(), device=audio_codes_input.device + 0, max_codebook_val, audio_codes_input_unstacked.size(), device=audio_codes_input_unstacked.device ) random_audio_tokens = random_audio_tokens * audio_codes_mask.unsqueeze(1) dec_dropout_mask = ( - torch.rand((1, 1, audio_codes_input.size(2)), device=audio_codes_input.device) + torch.rand((1, 1, audio_codes_input_unstacked.size(2)), device=audio_codes_input_unstacked.device) > self.decoder_input_dropout_prob ) # timestep_mask is True for timesteps to be kept - audio_codes_input = audio_codes_input * dec_dropout_mask + random_audio_tokens * (~dec_dropout_mask) - audio_codes_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T', E) + audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * (~dec_dropout_mask) + audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked) # (B, T', E) if context_tensors['additional_decoder_input'] is not None: dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1) @@ -1267,32 +1444,41 @@ def process_batch(self, batch, mode="train"): dec_context_size = context_tensors['dec_context_size'] logits = logits[:, dec_context_size:, :] # Remove the context audio embeddings from the logits - codebook_loss, loss_mask = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target) + # Codebook loss (parallel) + codebook_loss, loss_mask = self.compute_loss( + logits, audio_codes_target_unstacked, + audio_codes_lens_target_unstacked, + frame_stacking_factor=self.frame_stacking_factor + ) + # Alignment loss alignment_loss = None if self.alignment_loss_scale > 0.0 and not disable_alignment_loss: text_lens = context_tensors['text_lens'] cross_attention_scores = [attn['cross_attn_probabilities'][1] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids] alignment_loss = self.compute_alignment_loss( - cross_attention_scores, text_lens, audio_codes_lens_target, dec_context_size + cross_attention_scores, text_lens, audio_codes_lens_input, dec_context_size ) loss = self.codebook_loss_scale * codebook_loss + alignment_loss else: loss = self.codebook_loss_scale * codebook_loss + # Local Transformer loss local_transformer_loss = None local_transformer_logits = None if self.local_transformer_type != LocalTransformerType.NO_LT: if self.local_transformer_type == LocalTransformerType.MASKGIT: + ## Maskgit ## # randomly replace some positions with MASK_TOKEN - audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target) + audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked) + # TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part of + # of a pair where the first position is valid. Is this an issue? 
local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_masked, targets_offset_by_one=True) - #audio_codes_masked = audio_codes_masked[:, 1:, :] - local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target, audio_codes_lens_target, mask_tokens_mask) + local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target_unstacked, audio_codes_lens_target_unstacked, mask_tokens_mask, frame_stacking_factor=self.frame_stacking_factor) else: - # autoregressive + ## Autoregressive ## assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type" - local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_target, targets_offset_by_one=False) - local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target, audio_codes_lens_target, None) + local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_target_unstacked, targets_offset_by_one=False) + local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target_unstacked, audio_codes_lens_target_unstacked, None, frame_stacking_factor=self.frame_stacking_factor) loss = loss + self.local_transformer_loss_scale * local_transformer_loss if aligner_encoder_loss is not None: @@ -1308,8 +1494,8 @@ def process_batch(self, batch, mode="train"): 'loss_mask': loss_mask, 'alignment_loss': alignment_loss, 'aligner_encoder_loss': aligner_encoder_loss, - 'audio_codes_target': audio_codes_target, - 'audio_codes_lens_target': audio_codes_lens_target, + 'audio_codes_target': audio_codes_target_unstacked, + 'audio_codes_lens_target': audio_codes_lens_target_unstacked, 'text': context_tensors['text'], 'text_lens': context_tensors['text_lens'], 'context_audio_codes': context_tensors['context_audio_codes'], @@ -1572,6 +1758,23 @@ def get_inference_attention_plots(self, cross_attention_scores_all_timesteps, al return cross_attention_maps, headwise_cross_attention_maps + def find_eos_frame_index(self, codes) -> Optional[int]: + """ + Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack + that contains an EOS token across any codebook, or `None` if no EOS is found. 
+ Args: + codes: (num_codebooks, frame_stacking_factor) + Returns: + index (within the frame stack) of the first frame with EOS, or `None` if no EOS is found + """ + eos_mask = (codes == self.audio_eos_id) # (codebooks, frame_stacking_factor) + eos_per_frame = eos_mask.any(dim=0) # (frame_stacking_factor,) - True if any codebook has EOS in this frame + # find first frame with EOS + if eos_per_frame.any(): + # return index of the first frame with EOS + return eos_per_frame.nonzero()[0].item() + return None + def infer_batch( self, batch, @@ -1590,8 +1793,11 @@ def infer_batch( compute_all_heads_attn_maps=False, use_local_transformer_for_inference=False, use_LT_kv_cache=True, - maskgit_n_steps=3 - ): + maskgit_n_steps=3, + maskgit_noise_scale=0.0, + maskgit_fixed_schedule=None, + maskgit_dynamic_cfg_scale=False, + maskgit_sampling_type=None): with torch.no_grad(): start_time = time.time() self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference) @@ -1599,9 +1805,9 @@ def infer_batch( context_tensors = self.prepare_context_tensors(batch) text = context_tensors['text'] audio_codes_bos = torch.full( - (text.size(0), self.num_audio_codebooks, 1), self.audio_bos_id, device=text.device + (text.size(0), self.num_audio_codebooks, self.frame_stacking_factor), self.audio_bos_id, device=text.device ).long() - audio_codes_lens = torch.full((text.size(0),), 1, device=text.device).long() + audio_codes_lens = torch.full((text.size(0),), 1, device=text.device).long() # intetionally 1 rather than self.frame_stacking_factor since this is in stacked form audio_codes_input = audio_codes_bos audio_codes_mask = get_mask_from_lengths(audio_codes_lens) @@ -1626,7 +1832,7 @@ def infer_batch( attended_timestep_counter = [{} for _ in range(text.size(0))] last_attended_timesteps = [[1 for _ in range(text.size(0))]] # Maintain a list of attended timesteps as we predict audio for each batch item time_to_first_prediction = 0.0 - for idx in range(max_decoder_steps): + for idx in range(max_decoder_steps // self.frame_stacking_factor): if idx == 1: time_to_first_prediction = time.time() - start_time if idx % 20 == 0: @@ -1763,39 +1969,49 @@ def infer_batch( finished_items=finished_items, use_cfg=use_cfg, cfg_scale=cfg_scale, - n_steps=maskgit_n_steps + n_steps=maskgit_n_steps, + noise_scale=maskgit_noise_scale, + fixed_schedule=maskgit_fixed_schedule, + dynamic_cfg_scale=maskgit_dynamic_cfg_scale, + sampling_type=maskgit_sampling_type ) else: raise ValueError(f"Local transformer inference requested by but local transformer type is {self.local_transformer_type}") - # TODO @rfejgin: should we add argmax sampling for EOS here too? 
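# Illustrative sketch (toy values, hypothetical EOS id) of the check implemented by
# `find_eos_frame_index` above: within one stacked decoder step, the first frame in
# which *any* codebook emits the EOS id decides where generation ends.
import torch

audio_eos_id = 999                                # hypothetical id, for illustration only
codes = torch.tensor([[12, 999],
                      [ 7,  31],
                      [ 5, 999]])                 # (3 codebooks, frame_stacking_factor 2)
eos_per_frame = (codes == audio_eos_id).any(dim=0)                      # tensor([False, True])
frame_index = eos_per_frame.nonzero()[0].item() if eos_per_frame.any() else None
print(frame_index)                                # 1 -> EOS falls in the second frame of the stack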
- all_codes_next_argmax = audio_codes_next else: # Parallel sampling from all codebooks - audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks) - all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks) + audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks, frame_stacking_factor) + all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks, frame_stacking_factor) for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: - eos_in_pred_tokens_argmax = (all_codes_next_argmax[item_idx] == self.audio_eos_id).any().item() - eos_in_pred_tokens_multinomial = (audio_codes_next[item_idx] == self.audio_eos_id).any().item() - if eos_in_pred_tokens_argmax or eos_in_pred_tokens_multinomial: - print("End detected for item {} at timestep {}".format(item_idx, idx)) - end_indices[item_idx] = idx + # check for EOS (including within the frame stack) + eos_frame_multinomial = self.find_eos_frame_index(audio_codes_next[item_idx]) + eos_frame_argmax = self.find_eos_frame_index(all_codes_next_argmax[item_idx]) + eos_frame_multinomial = eos_frame_multinomial if eos_frame_multinomial is not None else float('inf') + eos_frame_argmax = eos_frame_argmax if eos_frame_argmax is not None else float('inf') + # pick minimum of the two + frame_index = min(eos_frame_multinomial, eos_frame_argmax) + if frame_index != float('inf'): + global_index = idx * self.frame_stacking_factor + frame_index + end_indices[item_idx] = global_index + print(f"End detected for item {item_idx} at decoder timestep: {idx}") all_predictions.append(audio_codes_next) audio_codes_input = torch.cat( - [audio_codes_input, audio_codes_next.unsqueeze(-1)], dim=-1 + [audio_codes_input, audio_codes_next], dim=-1 ) # (B, C, T') - audio_codes_lens = audio_codes_lens + 1 + audio_codes_lens = audio_codes_lens + 1 # already in stacked form audio_codes_mask = get_mask_from_lengths(audio_codes_lens) if len(end_indices) == text.size(0) and len(all_predictions) >= 4: # Codec must be of atleast 4 timesteps to be decoded properly print("All ends reached") break tts_generation_time = time.time() - start_time - tts_generation_time_per_frame = tts_generation_time / len(all_predictions) - predicted_codes = torch.stack(all_predictions, dim=-1) # (B, num_codebooks, T') + tts_generation_time_per_frame = tts_generation_time / (len(all_predictions)*self.frame_stacking_factor) + # Concatenate the list of predictions along the time dimension. Note that when frame stacking is on, + # this also undoes the stacking. 
+ predicted_codes = torch.cat(all_predictions, dim=-1) # (B, num_codebooks, T') predicted_lens = [end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))] # Ensure that the codec is atleast of length 4 predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long() @@ -2020,4 +2236,3 @@ def setup_test_data(self, dataset_cfg): @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: return [] - diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index eda50b5bfd1e..d5c0bcaac7cc 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from enum import Enum from nemo.utils.enum import PrettyStrEnum import torch @@ -20,6 +21,7 @@ from torch import Tensor from nemo.core.classes.module import NeuralModule + class LocalTransformerType(PrettyStrEnum): """ Enum for the type of local transformer to use in the MagpieTTS model. @@ -35,7 +37,7 @@ class SpecialAudioToken(Enum): """ Enum for the special tokens to use in the MagpieTTS model. The special tokens are appended at the end of the codebook after the actual audio codec tokens. - The actual codeco index is this value below plus the number of codec tokens - do not use the Enum directly. + The actual embedding table index is the value below plus the number of codec tokens - do not use the Enum directly. """ AUDIO_BOS = 0 @@ -48,6 +50,28 @@ class SpecialAudioToken(Enum): RESERVED_2 = 6 RESERVED_3 = 7 + @staticmethod + def get_index(token: SpecialAudioToken, base_codebook_size: int): + """ + Returns the index of the special token in the embedding table. + """ + return base_codebook_size + token.value + + @staticmethod + def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = True) -> list[int]: + """ + Returns a list of token indices that should not be sampled or returned to user. + Args: + base_codebook_size (int): The size of the codec codebook (which is the first part of the embedding table). + forbid_audio_eos (bool): Whether to forbid the AUDIO_EOS token to be sampled. 
+ * Set to `False` when internally generating tokens in MagpieTTS sampling + * Set to `True` when checking validity of tokens to be returned to user + or given to the codec for decoding + """ + all_special_tokens = list(SpecialAudioToken) + if not forbid_audio_eos: + all_special_tokens.remove(SpecialAudioToken.AUDIO_EOS) + return [SpecialAudioToken.get_index(token, base_codebook_size) for token in all_special_tokens] def cosine_schedule(x: torch.Tensor): """ diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index f0ca0d0133f6..1676f34a7f28 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -137,12 +137,14 @@ def binarize_attention_parallel(attn, in_lens, out_lens): def get_mask_from_lengths( lengths: Optional[torch.Tensor] = None, x: Optional[torch.Tensor] = None, + pad_to_factor: Optional[int] = None ) -> torch.Tensor: """Constructs binary mask from a 1D torch tensor of input lengths Args: lengths: Optional[torch.tensor] (torch.tensor): 1D tensor with lengths x: Optional[torch.tensor] = tensor to be used on, last dimension is for mask + pad_to_factor: Optional[int] = pad the mask to an integer multiple of this factor Returns: mask (torch.tensor): num_sequences x max_length binary tensor """ @@ -154,6 +156,8 @@ def get_mask_from_lengths( max_len = torch.max(lengths) else: max_len = x.shape[-1] + if pad_to_factor is not None: + max_len = torch.ceil(max_len / pad_to_factor) * pad_to_factor ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) mask = ids < lengths.unsqueeze(1) return mask diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 3a492ac520a0..f4de0f7eea8e 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -77,7 +77,10 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False, legacy_tex # For older checkpoints trained with a different parameter name model_cfg.local_transformer_type = "autoregressive" del model_cfg.use_local_transformer - + if hasattr(model_cfg, 'downsample_factor'): + # Backward compatibility for models trained with the config option`downsample_factor` which was later renamed to `frame_stacking_factor` + model_cfg.frame_stacking_factor = model_cfg.downsample_factor + del model_cfg.downsample_factor if legacy_codebooks: # Added to address backward compatibility arising from # https://github.com/blisc/NeMo/pull/64 @@ -273,6 +276,9 @@ def run_inference( confidence_level=0.95, use_local_transformer=False, maskgit_n_steps=3, + maskgit_noise_scale=0.0, + maskgit_fixed_schedule=None, + maskgit_sampling_type=None, legacy_codebooks=False, legacy_text_conditioning=False, clean_up_disk=False, @@ -326,22 +332,22 @@ def run_inference( else: exp_name = "" - checkpoint_name = "{}{}_Temp{}_Topk{}_Cfg_{}_{}_Prior_{}_LT_{}_MGsteps_{}_ST_{}_sched_{}".format( - exp_name, - checkpoint_name, - temperature, - topk, - use_cfg, - cfg_scale, - apply_attention_prior, - attention_prior_epsilon, - attention_prior_lookahead_window, - start_prior_after_n_audio_steps, - "".join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else "None", - "".join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else "None", - use_local_transformer, - maskgit_n_steps, - sv_model + # Build checkpoint name + checkpoint_name = ( + 
f"{exp_name}{checkpoint_name}_Temp{temperature}_Topk{topk}_Cfg_{use_cfg}_{cfg_scale}_" + f"Prior_{apply_attention_prior}_" + ) + if apply_attention_prior: + # Only add prior config details if prior is enabled (to avoid super long checkpoint names) + checkpoint_name += ( + f"{attention_prior_epsilon}_{attention_prior_lookahead_window}_{start_prior_after_n_audio_steps}_" + f"{''.join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else 'None'}_" + f"{''.join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else 'None'}_" + ) + checkpoint_name += ( + f"LT_{use_local_transformer}_" + f"MaskGit_{maskgit_n_steps}_{maskgit_sampling_type}_{''.join([str(l) for l in maskgit_fixed_schedule]) if maskgit_fixed_schedule is not None else 'None'}_" + f"SV_{sv_model}" ) dataset_meta_info = evalset_config.dataset_meta_info @@ -459,6 +465,9 @@ def run_inference( start_prior_after_n_audio_steps=start_prior_after_n_audio_steps, use_local_transformer_for_inference=use_local_transformer, maskgit_n_steps=maskgit_n_steps, + maskgit_noise_scale=maskgit_noise_scale, + maskgit_fixed_schedule=maskgit_fixed_schedule, + maskgit_sampling_type=maskgit_sampling_type ) all_rtf_metrics.append(rtf_metrics) @@ -591,6 +600,9 @@ def main(): parser.add_argument('--use_cfg', action='store_true') parser.add_argument('--use_local_transformer', action='store_true', help="Enables use of local transformer for inference; applies to both Autoregressive and MaskGit sampling.") parser.add_argument('--maskgit_n_steps', type=int, default=3) + parser.add_argument('--maskgit_noise_scale', type=float, default=0.0) + parser.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) + parser.add_argument('--maskgit_sampling_type', default=None, choices=["default", "causal", "purity_causal", "purity_default"]) parser.add_argument('--cfg_scale', type=float, default=2.5) parser.add_argument('--apply_attention_prior', action='store_true') parser.add_argument('--attention_prior_epsilon', type=float, default=0.1) @@ -650,6 +662,9 @@ def main(): confidence_level=args.confidence_level, use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, + maskgit_noise_scale=args.maskgit_noise_scale, + maskgit_fixed_schedule=args.maskgit_fixed_schedule, + maskgit_sampling_type=args.maskgit_sampling_type, legacy_codebooks=args.legacy_codebooks, legacy_text_conditioning=args.legacy_text_conditioning, clean_up_disk=args.clean_up_disk, From bce14cee88fd806f6f0d8cb9dfa1cfcc90efd7ea Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Thu, 21 Aug 2025 19:42:56 -0400 Subject: [PATCH 077/113] Forbid sampling of special tokens except AUDIO_EOS (#14555) We have observed the model sometimes generating special tokens (e.g. a reserved token), which should never be sampled. Noted in particular when CFG is on. We now excplitly ensure that such tokens are never sampled by forcing the corresponding logits (post-CFG, pre top-k) to -inf. 
Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 4c4e59141c69..ff8806186bb7 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -693,14 +693,14 @@ def code_to_str(code): output_str += c logging.debug(output_str) - def clear_forbidden_logits(self, logits, clear_audio_eos=False): + def clear_forbidden_logits(self, logits): """ - Sets logits for forbidden tokens to `-inf`. + Sets logits of forbidden tokens to `-inf` so they will never be sampled. + Specifically, we forbid sampling of all special tokens except AUDIO_EOS. Args: logits: (B, C, num_audio_tokens_per_codebook) - clear_audio_eos: bool, whether to clear the audio EOS token """ - logits[:,:, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=clear_audio_eos)] = float('-inf') + logits[:, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=False)] = float('-inf') return logits def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0, n_steps=3, noise_scale=0.0, fixed_schedule=None, dynamic_cfg_scale=False, sampling_type=None): @@ -793,7 +793,7 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, logits[:actual_batch_size] = cfg_logits # Disallow generation of special tokens (except audio EOS which is handled separately) - logits = self.clear_forbidden_logits(logits, clear_audio_eos=False) + logits = self.clear_forbidden_logits(logits) # handle unfinished and finished items for item_idx in unfinished_items: @@ -865,6 +865,7 @@ def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, t codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 + codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(-1) # (B, num_tokens_per_codebook) codebook_logits_rescored = codebook_logits.clone() @@ -899,6 +900,7 @@ def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, for item_idx in finished_items: codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 + codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 From 368f6bfa4e664c61d5cd00fd65c66fa048eaca9b Mon Sep 17 00:00:00 2001 From: blisc Date: Fri, 22 Aug 2025 21:05:11 +0000 Subject: [PATCH 078/113] Apply isort and black reformatting Signed-off-by: blisc --- examples/tts/magpietts.py | 11 +- .../common/data/lhotse/dataloader.py | 2 +- .../common/data/lhotse/sampling.py | 18 +- .../text_to_speech/tts_tokenizers.py | 3 +- .../tts/data/text_to_speech_dataset.py | 8 +- .../tts/data/text_to_speech_dataset_lhotse.py | 28 +- nemo/collections/tts/models/__init__.py | 2 +- nemo/collections/tts/models/audio_codec.py | 16 +- nemo/collections/tts/models/magpietts.py | 836 ++++++++++++------ .../magpietts_preference_optimization.py | 434 ++++++--- 
.../tts/modules/audio_codec_modules.py | 54 +- nemo/collections/tts/modules/fcd_metric.py | 19 +- .../tts/modules/magpietts_modules.py | 40 +- .../tts/modules/transformer_2501.py | 8 +- nemo/collections/tts/parts/utils/helpers.py | 4 +- scripts/magpietts/codec_extraction.py | 63 +- .../magpietts/dpo/create_preference_pairs.py | 18 +- .../magpietts/dpo/create_text_contextpairs.py | 9 +- scripts/magpietts/eval_squimmos.py | 18 +- scripts/magpietts/evalset_config.py | 50 +- scripts/magpietts/evaluate_generated_audio.py | 196 ++-- scripts/magpietts/infer_and_evaluate.py | 263 +++--- scripts/tts_dataset_to_lhotse/create_shars.py | 40 +- .../common/test_lhotse_dataloading.py | 4 + .../tts/modules/test_fcd_metric.py | 2 +- 25 files changed, 1378 insertions(+), 768 deletions(-) diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index 8c74111ebef2..c72da5318287 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -18,8 +18,8 @@ from nemo.collections.tts.models import ( MagpieTTSModel, - MagpieTTSModelOfflinePODataGen, MagpieTTSModelOfflinePO, + MagpieTTSModelOfflinePODataGen, MagpieTTSModelOnlinePO, ) from nemo.core.config import hydra_runner @@ -69,9 +69,7 @@ def main(cfg): elif mode == 'test': model = MagpieTTSModelOfflinePODataGen(cfg=cfg.model, trainer=trainer) else: - raise NotImplementedError( - f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}" - ) + raise NotImplementedError(f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}") model.maybe_init_from_pretrained_checkpoint(cfg=cfg) @@ -83,7 +81,9 @@ def main(cfg): logging.info("Starting testing...") trainer.test(model) else: - raise NotImplementedError(f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}") + raise NotImplementedError( + f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}" + ) logging.info("Training/testing completed successfully.") finally: # Ensure WandB completes all uploads before Python thread shutdown @@ -91,6 +91,7 @@ def main(cfg): # overwhelmed and fail to properly coordinate with WandB's background threads try: import wandb + if wandb.run is not None: logging.info("Finishing WandB run to prevent threading shutdown hang...") wandb.finish() diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index facaf49446d2..38de6bf14add 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -45,7 +45,6 @@ ) from nemo.collections.common.data.lhotse.sampling import ( BucketingFilter, - ValidationStatusFilter, CERFilter, ContextSpeakerSimilarityFilter, DurationFilter, @@ -55,6 +54,7 @@ TokenCountFilter, TokenPerSecondFilter, TokenPerTokenFilter, + ValidationStatusFilter, ) from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn from nemo.collections.common.prompts import PromptFormatter diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 37499aad6382..e17dd2bb3a2c 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -267,20 +267,27 @@ def __call__(self, example) -> bool: else: return True # does not apply to text etc. + class ValidationStatusFilter: """ Callable, returns ``True`` if a cut's validation status is equal to keep and ``False`` otherwise. Acts as a pass-through for objects of other type than Cut. 
""" + def __init__(self, keep: str = "pass") -> None: self.keep = keep def __call__(self, example) -> bool: - if isinstance(example, Cut) and example.has_custom("validation_status") and example.validation_status != self.keep: + if ( + isinstance(example, Cut) + and example.has_custom("validation_status") + and example.validation_status != self.keep + ): return False else: return True + class CERFilter: """ Callable, returns ``True`` if a cut's CER is less than max_cer and ``False`` otherwise. @@ -296,20 +303,27 @@ def __call__(self, example) -> bool: else: return True + class ContextSpeakerSimilarityFilter: """ Callable, returns ``True`` if a cut's context speaker similarity is greater than min_context_speaker_similarity and ``False`` otherwise. Acts as a pass-through for objects of other type than Cut. """ + def __init__(self, min_context_speaker_similarity: float | None) -> None: self.min_context_speaker_similarity = ifnone(min_context_speaker_similarity, -1) def __call__(self, example) -> bool: - if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("context_speaker_similarity"): + if ( + isinstance(example, Cut) + and len(example.supervisions) > 0 + and example.supervisions[0].has_custom("context_speaker_similarity") + ): return example.supervisions[0].context_speaker_similarity >= self.min_context_speaker_similarity else: return True + class TokenCountFilter: """ Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise. diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index e6d0896beabe..8e5b0eae9e8f 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -1134,7 +1134,7 @@ def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase elif hasattr(first_tokenizer, "pad"): # Defined in BaseTokenizer subclasses self.pad = first_tokenizer.pad else: - raise ValueError("AggregatedTTSTokenizer could not find a padding token in the first tokenizer") + raise ValueError("AggregatedTTSTokenizer could not find a padding token in the first tokenizer") def encode(self, text: str, tokenizer_name: str = None) -> List[int]: tokenizer = self.tokenizers[tokenizer_name] @@ -1144,4 +1144,3 @@ def encode(self, text: str, tokenizer_name: str = None) -> List[int]: def decode(self, tokens: List[int], tokenizer_name: str = None) -> str: tokenizer = self.tokenizers[tokenizer_name] return tokenizer.decode([token - self.tokenizer_offsets[tokenizer_name] for token in tokens]) - diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 4d680a56912d..6e1e20725c89 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -397,7 +397,7 @@ def __init__( max_duration=max_duration, volume_norm=volume_norm, ) - self.bos_id = bos_id # TODO @xueyang: this should be removed since no other places used it. + self.bos_id = bos_id # TODO @xueyang: this should be removed since no other places used it. 
self.eos_id = eos_id self.audio_bos_id = audio_bos_id self.audio_eos_id = audio_eos_id @@ -578,7 +578,9 @@ def __getitem__(self, index): if self.use_text_conditioning_tokenizer: if 'context_text' in data.manifest_entry: - context_tokens = self.text_tokenizer.encode(data.manifest_entry['context_text'], self.text_conditioning_tokenizer_name) + context_tokens = self.text_tokenizer.encode( + data.manifest_entry['context_text'], self.text_conditioning_tokenizer_name + ) example['has_text_context'] = True else: context_tokens = self.text_tokenizer.encode("[NO TEXT CONTEXT]", self.text_conditioning_tokenizer_name) @@ -612,7 +614,7 @@ def __getitem__(self, index): example['raw_text'] = data.manifest_entry['original_text'] else: example['raw_text'] = data.text - + example['language'] = data.manifest_entry.get('language', 'en') if "reward" in data.manifest_entry: diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index 31e8827165c1..d427c17fefb4 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -54,9 +54,9 @@ def setup_tokenizers(all_tokenizers_config, mode='train'): tokenizer.set_phone_prob(1.0) tokenizers.append(tokenizer) tokenizer_names.append(tokenizer_name) - + aggregated_tokenizer = AggregatedTTSTokenizer(tokenizers, tokenizer_names) # TTS Transcript tokenizer - + return aggregated_tokenizer @@ -127,6 +127,7 @@ class MagpieTTSLhotseDataset(torch.utils.data.Dataset): Used for lazy initialization within workers. Must be provided if tokenizers are not set externally. Defaults to None. """ + def __init__( self, sample_rate: int, @@ -216,7 +217,9 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: context_text_tokens_len_list = [] context_has_text_context_list = [] reward_list = [] - raw_text_list = [] # raw text here is the string of normalized text or text stored in the supervision segment. Used to distinguish from text tokens. + raw_text_list = ( + [] + ) # raw text here is the string of normalized text or text stored in the supervision segment. Used to distinguish from text tokens. for cut in cuts: speaker = cut.supervisions[0].speaker if not check_speaker_format(speaker): @@ -332,7 +335,9 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: ) context_audio_codes = torch.cat([context_bos_tensor, context_eos_tensor], dim=1) context_audio_codes_len = context_audio_codes.shape[1] - context_audio_codes_list.append(context_audio_codes.T) # transpose to (T, C) to use collate_matrices to process batch. + context_audio_codes_list.append( + context_audio_codes.T + ) # transpose to (T, C) to use collate_matrices to process batch. 
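# Illustrative aside (toy shapes, assuming lhotse's `collate_matrices` pads 2-D tensors
# along their leading dimension): the codes are transposed to (T, C) so padding is applied
# over time, then transposed back to (B, C, T) after batching.
import torch
from lhotse.dataset.collation import collate_matrices

codes_a = torch.randint(0, 1000, (5, 8))          # (T=5, C=8)
codes_b = torch.randint(0, 1000, (3, 8))          # (T=3, C=8)
batch = collate_matrices([codes_a, codes_b], padding_value=0)   # (B=2, T=5, C=8)
batch = batch.transpose(1, 2)                                   # (B=2, C=8, T=5)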
context_audio_codes_len_list.append(context_audio_codes_len) else: # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes @@ -364,10 +369,14 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: if self.use_text_conditioning_tokenizer: if cut.supervisions[0].has_custom("context_text"): - context_text_tokens = self.text_tokenizer.encode(cut.supervisions[0].context_text, tokenizer_name=self.text_conditioning_tokenizer_name) + context_text_tokens = self.text_tokenizer.encode( + cut.supervisions[0].context_text, tokenizer_name=self.text_conditioning_tokenizer_name + ) has_text_context = True else: - context_text_tokens = self.text_tokenizer.encode("[NO TEXT CONTEXT]", tokenizer_name=self.text_conditioning_tokenizer_name) + context_text_tokens = self.text_tokenizer.encode( + "[NO TEXT CONTEXT]", tokenizer_name=self.text_conditioning_tokenizer_name + ) has_text_context = False if self.pad_context_text_to_max_duration: _required_len = ( @@ -445,12 +454,15 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: batch_dict["context_audio_lens"] = torch.IntTensor(context_audio_len_list) if len(context_audio_codes_list) > 0: # transpose back to (B, 8, T) from (B, T, 8). - batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose(1, 2) + batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose( + 1, 2 + ) batch_dict["context_audio_codes_lens"] = torch.IntTensor(context_audio_codes_len_list) if self.use_text_conditioning_tokenizer: batch_dict['context_text_tokens'] = collate_vectors( - tensors=context_text_tokens_list, padding_value=self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] + tensors=context_text_tokens_list, + padding_value=self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name], ) batch_dict['context_text_tokens_lens'] = torch.IntTensor(context_text_tokens_len_list) batch_dict['has_text_context'] = torch.BoolTensor(context_has_text_context_list) diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index 612275ce35cc..37b6a9a50aaf 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -20,8 +20,8 @@ from nemo.collections.tts.models.magpietts import MagpieTTSModel from nemo.collections.tts.models.magpietts_preference_optimization import ( MagpieTTSModelOfflinePO, - MagpieTTSModelOnlinePO, MagpieTTSModelOfflinePODataGen, + MagpieTTSModelOnlinePO, ) from nemo.collections.tts.models.mixer_tts import MixerTTSModel from nemo.collections.tts.models.radtts import RadTTSModel diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 1a1597be138c..785fa62210d9 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -358,7 +358,7 @@ def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Te tokens = rearrange(tokens, 'B C T -> C B T') with default_precision(torch.float32): dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) - dequantized = dequantized.to(self.dtype) # make sure dequantized is in the right dtype + dequantized = dequantized.to(self.dtype) # make sure dequantized is in the right dtype return dequantized @typecheck( @@ -411,7 +411,7 @@ def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> 
Tuple[torch. """ # Convert a discrete representation to a dequantized vector for each frame dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len) - dequantized = dequantized.to(self.dtype) # make sure that the dequantized is in the model dtype + dequantized = dequantized.to(self.dtype) # make sure that the dequantized is in the model dtype # Apply decoder to obtain time-domain audio for each frame audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len) @@ -489,12 +489,12 @@ def _process_batch(self, batch): encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len) commit_loss = 0.0 - encoded = encoded.to(encoded.dtype) # make sure encoded is converted to the right dtype + encoded = encoded.to(encoded.dtype) # make sure encoded is converted to the right dtype else: commit_loss = 0.0 # [B, T] - encoded = encoded.to(self.dtype) # make sure vector quantizer output is in the model dtype + encoded = encoded.to(self.dtype) # make sure vector quantizer output is in the model dtype audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len) return audio, audio_len, audio_gen, commit_loss @@ -536,7 +536,9 @@ def training_step(self, batch, batch_idx): generator_losses = [] # stft does not support bf16, so make it run in fp32 - loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) + loss_mel_l1, loss_mel_l2 = self.mel_loss_fn( + audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len + ) if self.mel_loss_l1_scale: metrics["g_loss_mel_l1"] = loss_mel_l1 @@ -623,7 +625,9 @@ def on_train_epoch_end(self): def validation_step(self, batch, batch_idx): audio, audio_len, audio_gen, _ = self._process_batch(batch) - loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) + loss_mel_l1, loss_mel_l2 = self.mel_loss_fn( + audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len + ) loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index ff8806186bb7..16d28be553e6 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -14,8 +14,10 @@ import os import random import time -from typing import List, Optional from functools import partial +from typing import List, Optional + +import numpy as np import soundfile as sf import torch import wandb @@ -25,7 +27,6 @@ from omegaconf import DictConfig, OmegaConf, open_dict from torch import nn from torch.utils.data import get_worker_info -import numpy as np import nemo.collections.asr as nemo_asr from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config @@ -36,8 +37,8 @@ from nemo.collections.tts.modules.aligner import AlignmentEncoder from nemo.collections.tts.modules.magpietts_modules import ( CharAwareSubwordEncoder, - SpecialAudioToken, LocalTransformerType, + SpecialAudioToken, cosine_schedule, ) from nemo.collections.tts.parts.utils.helpers import ( @@ -49,18 +50,17 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging + def worker_init_fn(worker_id): # For mp.set_start_method("spawn", 
force=True) # The dataset class should be picklable, so we initialize non-picklable objects here logging.info(f"Worker {worker_id} initializing...") worker_info = get_worker_info() dataset = worker_info.dataset # Get the dataset instance in this worker - tokenizer = setup_tokenizers( - dataset.tokenizer_config, - mode=dataset.dataset_type - ) + tokenizer = setup_tokenizers(dataset.tokenizer_config, mode=dataset.dataset_type) dataset.text_tokenizer = tokenizer + class MagpieTTSModel(ModelPT): """ Magpie-TTS Model Base Class used for training a TTS model that can generate audio codes from transcript and a context @@ -106,10 +106,16 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=num_audio_tokens) self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS)) self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS)) - self.context_audio_bos_id = cfg.get('forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS)) - self.context_audio_eos_id = cfg.get('forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS)) + self.context_audio_bos_id = cfg.get( + 'forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS) + ) + self.context_audio_eos_id = cfg.get( + 'forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS) + ) self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN)) - self.num_all_tokens_per_codebook = cfg.get('forced_num_all_tokens_per_codebook', num_audio_tokens + len(SpecialAudioToken)) + self.num_all_tokens_per_codebook = cfg.get( + 'forced_num_all_tokens_per_codebook', num_audio_tokens + len(SpecialAudioToken) + ) self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) # The frame stacking factor controls how many consecutive frames are processed together by the base decoder @@ -129,19 +135,19 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # Using google-t5/t5-small as default text conditioning tokenizer for backward compatibility. 
self.text_conditioning_tokenizer_name = cfg.get('text_conditioning_tokenizer_name', None) self.legacy_text_conditioning = cfg.get('legacy_text_conditioning', False) - + if self.legacy_text_conditioning: if self.text_conditioning_tokenizer_name is None: self.text_conditioning_tokenizer_name = "google-t5/t5-small" - + tokenizer_target = "AutoTokenizer" if self.text_conditioning_tokenizer_name == "google-t5/t5-small": tokenizer_target = "T5Tokenizer" with open_dict(cfg): cfg.text_tokenizers[self.text_conditioning_tokenizer_name] = { - '_target_' : tokenizer_target, - 'pretrained_model' : self.text_conditioning_tokenizer_name, + '_target_': tokenizer_target, + 'pretrained_model': self.text_conditioning_tokenizer_name, } elif self.text_conditioning_tokenizer_name is None: # If no text_conditioning_tokenizer_name is specified, use the first one as default @@ -178,7 +184,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # This needs to happen after super().__init__() self._codec_model = codec_model - self._codec_model.freeze() #Lightning does requires_grad = False and self.eval() + self._codec_model.freeze() # Lightning does requires_grad = False and self.eval() audio_embeddings = [] for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): @@ -204,7 +210,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): d_embed=cfg.embedding_dim, llm_tokenizer_vocab=subword_vocab, subword_padding_idx=self.tokenizer.pad, - special_vocab=special_vocab + special_vocab=special_vocab, ) else: # Regular text embedding @@ -213,7 +219,10 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) - self.final_proj = nn.Linear(cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor) + self.final_proj = nn.Linear( + cfg.decoder.d_model, + self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor, + ) self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower()) logging.info(f"Local transformer type: {self.local_transformer_type}") @@ -226,17 +235,19 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.local_transformer = transformer_2501.Transformer( n_layers=self.cfg.get('local_transformer_n_layers', 2), d_model=local_transformer_hidden_dim, - d_ffn=local_transformer_hidden_dim*4, + d_ffn=local_transformer_hidden_dim * 4, sa_n_heads=self.cfg.get('local_transformer_n_heads', 1), kernel_size=1, is_causal=self.local_transformer_type == LocalTransformerType.AR, - max_length_causal_mask=self.frame_stacking_factor*self.num_audio_codebooks+2, + max_length_causal_mask=self.frame_stacking_factor * self.num_audio_codebooks + 2, use_learnable_pos_emb=True, ) local_transformer_out_projections = [] for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): # Have a separate projection layer for each codebook, to distinguish between them - local_transformer_out_projections.append(nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook)) + local_transformer_out_projections.append( + nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook) + ) self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections) if cfg.get('use_alignment_encoder', False): @@ -253,7 +264,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): 
self._speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( model_name='titanet_large' ) - self._speaker_verification_model.freeze() #Lightning does requires_grad = False and self.eval() + self._speaker_verification_model.freeze() # Lightning does requires_grad = False and self.eval() self.speaker_projection_layer = nn.Linear(cfg.speaker_emb_dim, cfg.embedding_dim) self.transcript_decoder_layers = [ idx for idx in range(self.decoder.n_layers) @@ -287,7 +298,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): elif self.model_type == 'decoder_pretrain_synthesizer': # This is for pretraining the decoder only on audio data using next frame prediction loss - assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" + assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" else: raise ValueError(f"Unsupported model type {self.model_type}") @@ -324,7 +335,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # Configuration validity checks self.check_frame_stacking_config_validity() - def state_dict(self, destination=None, prefix='', keep_vars=False): """ Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model @@ -347,7 +357,7 @@ def check_frame_stacking_config_validity(self): if self.frame_stacking_factor > 1: # The settings below are not supported with frame stacking. # Some of them may work - but they have not been tested. - + # disallow alignment encoder if self.use_alignment_encoder: raise ValueError("Alignment encoder is not supported for frame stacking") @@ -376,7 +386,7 @@ def update_ckpt(self, state_dict): else: new_state_dict[key] = state_dict[key] return new_state_dict - + def load_state_dict(self, state_dict, strict=True): """ Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when @@ -388,10 +398,15 @@ def load_state_dict(self, state_dict, strict=True): if strict == False: super().load_state_dict(state_dict, strict=False) for name, child in self.named_children(): - if name in ['_speaker_verification_model', '_codec_model', '_reference_model', - 'eval_asr_model', 'eval_speaker_verification_model', - 'whisper_model', 'squim_objective_model' - ]: + if name in [ + '_speaker_verification_model', + '_codec_model', + '_reference_model', + 'eval_asr_model', + 'eval_speaker_verification_model', + 'whisper_model', + 'squim_objective_model', + ]: continue if any(param.numel() > 0 for param in child.parameters()): # If the module has parameters, we want to change the default mapping so that the state_dict gets @@ -401,7 +416,7 @@ def load_state_dict(self, state_dict, strict=True): for key in state_dict.keys(): name_with_dot = f"{name}." if key.startswith(name_with_dot): - new_state_dict[key[len(name_with_dot):]] = state_dict[key] + new_state_dict[key[len(name_with_dot) :]] = state_dict[key] child.load_state_dict(new_state_dict) def audio_to_codes(self, audio, audio_len, audio_type='target'): @@ -423,7 +438,7 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): bos_tensor = torch.full( (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device ) - # pad at the end to make room for the EOS token; the EOS token's actual position + # pad at the end to make room for the EOS token; the EOS token's actual position # varies per batch element depending on each element's length. 
pad_tensor = torch.full( (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device @@ -433,7 +448,7 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): # codes_len: (B,) for idx in range(codes.size(0)): codes[idx, :, codes_len[idx] + 1] = audio_eos_id - codes_len = codes_len + 2 # +1 for bos and +1 for eos + codes_len = codes_len + 2 # +1 for bos and +1 for eos return codes.long(), codes_len.long() def codes_to_audio(self, codes, codes_len): @@ -457,7 +472,7 @@ def embed_audio_tokens(self, audio_tokens): audio_embedding = None for i in range(self.frame_stacking_factor): for c in range(C): - tokens = audio_tokens[:,c , i::self.frame_stacking_factor] + tokens = audio_tokens[:, c, i :: self.frame_stacking_factor] embedding = self.audio_embeddings[c + i * C](tokens) if audio_embedding is None: audio_embedding = embedding @@ -494,48 +509,54 @@ def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_ +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | seq. index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ - + dec_out: (B, T', E) audio_codes_target: (B, C, T') targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive) if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit) """ C = self.num_audio_codebooks - dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E) + dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E) local_transformer_input = [dec_out_all] # Build the teacher-forced input to the LT. for fs_index in range(self.frame_stacking_factor): for codebook_num in range(C): - # Collect ground truth codes for the current codebook and frame stack index combintation. - codes = audio_codes_target[:, codebook_num, fs_index::self.frame_stacking_factor] # (B, T') + # Collect ground truth codes for the current codebook and frame stack index combintation. + codes = audio_codes_target[:, codebook_num, fs_index :: self.frame_stacking_factor] # (B, T') # Individual timesteps are independently handled by the LT fold time into the batch dimension. - codes = codes.reshape(-1) # (B*T',) + codes = codes.reshape(-1) # (B*T',) # Embed the codes - codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes) # (B*T', E) + codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes) # (B*T', E) local_transformer_input.append(codebook_embedding) # Stack the input codes along dimension 1 (codebooks). This is the dimension along which the LT predicts iteratively. 
- local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E) - local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128) - _mask = torch.ones(local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) - local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E) + local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E) + local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128) + _mask = torch.ones( + local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device + ) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E) if not targets_offset_by_one: # for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc. - local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E) + local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E) else: # for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc. - local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E) + local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E) all_code_logits = [] for fs_index in range(self.frame_stacking_factor): for codebook_num in range(audio_codes_target.size(1)): # Using a separate projection layer for each codebook (to distinguish between them) # Checked the time - this loop is not taking much time (compared to the local transformer forward pass) - codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index*C](local_transformer_output[:, codebook_num + fs_index*C, :]) # (B*T', num_all_tokens_per_codebook) + codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index * C]( + local_transformer_output[:, codebook_num + fs_index * C, :] + ) # (B*T', num_all_tokens_per_codebook) all_code_logits.append(codebook_logits) - all_code_logits = torch.cat(all_code_logits, dim=1) # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor) + all_code_logits = torch.cat( + all_code_logits, dim=1 + ) # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor) all_code_logits = all_code_logits.view( audio_codes_target.size(0), audio_codes_target.size(2) // self.frame_stacking_factor, -1 - ) # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor) + ) # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor) return all_code_logits @@ -544,13 +565,13 @@ def maskgit_create_random_mask(self, codes): Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN. """ # Codes: (B, C, T) - B,C,T = codes.shape + B, C, T = codes.shape # get a uniform random vector uniformly sampled from [0,1) ## Todo does it need to be inclusive on the right? 
- rand_values = torch.rand(B,T, device=codes.device) + rand_values = torch.rand(B, T, device=codes.device) # apply the cosine schedule frac_masked = cosine_schedule(rand_values) # how many positions to mask - n_masked = torch.ceil(frac_masked * C).long() # B,T + n_masked = torch.ceil(frac_masked * C).long() # B,T # start from all unmasked mask = torch.zeros_like(codes, dtype=torch.bool) # The code further below is the vectorized version of this: @@ -566,11 +587,11 @@ def maskgit_create_random_mask(self, codes): random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T) # Create a mask tensor where each position indicates if it should be masked mask_indices = torch.arange(C, device=codes.device).view(1, C, 1) - mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T) + mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T) # Apply the random permutations to the mask mask = torch.gather(mask, 1, random_permutations) - return mask # (B, C, T) + return mask # (B, C, T) def maskgit_apply_random_mask(self, codes): # Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule. @@ -628,7 +649,7 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=N else: total_codebook_loss = total_codebook_loss + codebook_loss - total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor) + total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor) return total_codebook_loss, loss_mask def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prior, multi_encoder_mapping): @@ -657,15 +678,17 @@ def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): # argmax to get the tokens codebook_preds = torch.argmax(codebook_probs, dim=-1) # (B, T') all_preds[fs_index].append(codebook_preds) - all_preds = [torch.stack(p, dim=1) for p in all_preds] # list of `frame_stacking_factor`` elements of shape (B,C,T) each - all_preds = torch.stack(all_preds, dim=-1) # B, C, T, frame_stacking_factor + all_preds = [ + torch.stack(p, dim=1) for p in all_preds + ] # list of `frame_stacking_factor`` elements of shape (B,C,T) each + all_preds = torch.stack(all_preds, dim=-1) # B, C, T, frame_stacking_factor # undo the frame stacking - all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1) # B, C, T*frame_stacking_factor + all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1) # B, C, T*frame_stacking_factor pred_max_len = all_preds.size(2) real_max_len = audio_codes_lens.max() assert (pred_max_len - real_max_len) < self.frame_stacking_factor # trim padding introduced for frame stacking - all_preds = all_preds[:,:, :real_max_len] + all_preds = all_preds[:, :, :real_max_len] audio_mask = get_mask_from_lengths(audio_codes_lens) all_preds = all_preds * audio_mask.unsqueeze(1) @@ -676,11 +699,13 @@ def visualize_codes(self, codes, mask_id=2020, frame_stacking_rate=2): Visualize codes for analysis purposes codes: (B, C) """ + def code_to_str(code): - if code==mask_id: + if code == mask_id: return "M " else: return f"{code:04d} " + B, C = codes.shape if B > 1: logging.debug("Warning: visualizing only first batch element") @@ -688,7 +713,7 @@ def code_to_str(code): codes = [code_to_str(c) for c in codes] output_str = "" for i, c in enumerate(codes): - if (i) % (C/frame_stacking_rate) == 0: + if (i) % (C / frame_stacking_rate) == 0: output_str += "|timestep| " output_str += c logging.debug(output_str) @@ -700,26 +725,44 @@ 
def clear_forbidden_logits(self, logits): Args: logits: (B, C, num_audio_tokens_per_codebook) """ - logits[:, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=False)] = float('-inf') + logits[ + :, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=False) + ] = float('-inf') return logits - def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0, n_steps=3, noise_scale=0.0, fixed_schedule=None, dynamic_cfg_scale=False, sampling_type=None): + def local_transformer_sample_maskgit( + self, + dec_output, + temperature=0.7, + topk=80, + unfinished_items={}, + finished_items={}, + use_cfg=False, + cfg_scale=1.0, + n_steps=3, + noise_scale=0.0, + fixed_schedule=None, + dynamic_cfg_scale=False, + sampling_type=None, + ): """ Sample codes for one timestep from the local transformer using MaskGit. - """ + """ # dec_output: (B, E) device = dec_output.device # disable KV cache since our transformer is not causal self.local_transformer.reset_cache(use_cache=False) - dec_output = dec_output.unsqueeze(1) # (B, 1, E) - local_transformer_input_init = self.local_transformer_in_projection(dec_output) # (B, 1, D) where D is the dimension of the local transformer + dec_output = dec_output.unsqueeze(1) # (B, 1, E) + local_transformer_input_init = self.local_transformer_in_projection( + dec_output + ) # (B, 1, D) where D is the dimension of the local transformer codebook_seq_len = self.num_audio_codebooks * self.frame_stacking_factor B = dec_output.size(0) min_confidence = 0 # this needs to be large enough that unmasked items will always remain unmasked (even after noise addition) # Setting it smaller could allow "regret", i.e. 
re-masking a codebook that was previously unmasked; we might want to try that - max_confidence = 5 + max_confidence = 5 confidences = min_confidence * torch.ones(B, codebook_seq_len, device=device) # initialize to all masked codes = self.mask_token_id * torch.ones((B, codebook_seq_len), device=device, dtype=torch.long) @@ -741,39 +784,55 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, n_masked = codebook_seq_len - fixed_schedule[step] n_unmasked = codebook_seq_len - n_masked - if sampling_type == "causal" or sampling_type == "purity_causal":# and n_unmasked <= self.num_audio_codebooks: + if ( + sampling_type == "causal" or sampling_type == "purity_causal" + ): # and n_unmasked <= self.num_audio_codebooks: # force second frame not to be unmasked - n_frames_to_allow = int(np.floor(progress*self.frame_stacking_factor+1)) - confidences[:,n_frames_to_allow*self.num_audio_codebooks:] = min_confidence-1 # only tested for frame_stacking_factor=2 + n_frames_to_allow = int(np.floor(progress * self.frame_stacking_factor + 1)) + confidences[:, n_frames_to_allow * self.num_audio_codebooks :] = ( + min_confidence - 1 + ) # only tested for frame_stacking_factor=2 # pick top-confidence codebooks up to n_unmasked _, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1) if use_cfg: actual_batch_size = topk_indices.size(0) // 2 - assert (topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size]).all(), f"Topk indices are not the same for conditional and unconditional codes" + assert ( + topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size] + ).all(), f"Topk indices are not the same for conditional and unconditional codes" # replace masks of the top-k confident codebooks with the codes that were sampled for them unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices) codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) - + # build transformer input local_transformer_input = local_transformer_input_init for codebook_num in range(codebook_seq_len): - next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze(1) # (B, 1, 768) - next_local_transformer_input = self.local_transformer_in_projection(next_local_transformer_input) # (B, 1, d_local) - local_transformer_input = torch.cat([local_transformer_input, next_local_transformer_input], dim=1) # (B, codebook_num+1, d_local) + next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze( + 1 + ) # (B, 1, 768) + next_local_transformer_input = self.local_transformer_in_projection( + next_local_transformer_input + ) # (B, 1, d_local) + local_transformer_input = torch.cat( + [local_transformer_input, next_local_transformer_input], dim=1 + ) # (B, codebook_num+1, d_local) # run transformer - _mask = torch.ones(B, codebook_seq_len+1, device=device) - local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, C+1, d_local) + _mask = torch.ones(B, codebook_seq_len + 1, device=device) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)[ + 'output' + ] # (B, C+1, d_local) # get logits logits = [] for codebook_num in range(codebook_seq_len): # The `codebook_num+1` is to drop first position which corresponds to the magpie latent - codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, codebook_num+1, :]) # (B, num_audio_tokens_per_codebook) + codebook_logits = 
self.local_transformer_out_projections[codebook_num]( + local_transformer_output[:, codebook_num + 1, :] + ) # (B, num_audio_tokens_per_codebook) logits.append(codebook_logits) - logits = torch.stack(logits, dim=1) # (B, C*frame_stacking_factor, num_audio_tokens_per_codebook) + logits = torch.stack(logits, dim=1) # (B, C*frame_stacking_factor, num_audio_tokens_per_codebook) # apply CFG if use_cfg: @@ -784,12 +843,12 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, current_cfg_scale = cfg_scale else: # gradually increase the scale until mid point through sampling, then reduce it again - progress = step / (n_steps-1) - #interp = -abs(progress-0.5)+0.5 # increase from 0..1 in the interval from start to midpoint and then go back to zero - #interp = 1.0 - progress # decrease from 1 to 0 - interp = progress # gradually increase from 0 to 1 + progress = step / (n_steps - 1) + # interp = -abs(progress-0.5)+0.5 # increase from 0..1 in the interval from start to midpoint and then go back to zero + # interp = 1.0 - progress # decrease from 1 to 0 + interp = progress # gradually increase from 0 to 1 current_cfg_scale = (cfg_scale - 1) * interp + 1.0 # 1.0 --> cfg_scale --> 1.0 - cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits + cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits logits[:actual_batch_size] = cfg_logits # Disallow generation of special tokens (except audio EOS which is handled separately) @@ -803,12 +862,12 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, logits[item_idx, :, self.audio_eos_id] = 0.0 # sample with top-k - logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk) - indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook) + logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk) + indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook) logits_rescored = logits.clone() logits_rescored[indices_to_remove] = float('-inf') - probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook) - sampled_codes = torch.multinomial(probs.view(B*codebook_seq_len, -1), 1).view(B, codebook_seq_len) + probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook) + sampled_codes = torch.multinomial(probs.view(B * codebook_seq_len, -1), 1).view(B, codebook_seq_len) if use_cfg: sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size] probs[actual_batch_size:] = probs[:actual_batch_size] @@ -823,40 +882,64 @@ def local_transformer_sample_maskgit(self, dec_output, temperature=0.7, topk=80, if noise_scale > 0.0: # get noise from uniform distribution in the interval [-0.5, 0.5), scale it by `noise_scale`, # and anneal it to 0 as we approach the end of the unmasking process - noise = (torch.rand_like(confidences) - 0.5) * noise_scale * (1-(step+2)/n_steps) # the +2 makes sure that by the last iteration the noise is exactly 0 + noise = ( + (torch.rand_like(confidences) - 0.5) * noise_scale * (1 - (step + 2) / n_steps) + ) # the +2 makes sure that by the last iteration the noise is exactly 0 confidences += noise # the conditional and unconditional get different noise and must be fixed to be the same again - confidences[actual_batch_size:] = confidences[:actual_batch_size] + confidences[actual_batch_size:] = 
confidences[:actual_batch_size] confidence_eps = 0.1 - assert confidences.max() + confidence_eps < max_confidence, f"Predicted confidence is approaching max_confidence: {confidences.max()}" + assert ( + confidences.max() + confidence_eps < max_confidence + ), f"Predicted confidence is approaching max_confidence: {confidences.max()}" # for unmasked codebooks, set confidence to max so that they will remain unmasked - confidences.scatter_(index=topk_indices, dim=1, src=max_confidence*torch.ones_like(topk_indices, dtype=torch.float)) + confidences.scatter_( + index=topk_indices, dim=1, src=max_confidence * torch.ones_like(topk_indices, dtype=torch.float) + ) codes = sampled_codes - assert not (codes == self.mask_token_id).any(), f"Codes contain mask tokens after completion of MaskGit sampling" + assert not ( + codes == self.mask_token_id + ).any(), f"Codes contain mask tokens after completion of MaskGit sampling" # break stacked groups of frames into individual frames - codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute(0,2,1) # B, C, frame_stacking_factor + codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute( + 0, 2, 1 + ) # B, C, frame_stacking_factor - if use_cfg: + if use_cfg: # drop unconditional codes codes = codes[:actual_batch_size] return codes - def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, topk=80, unfinished_items={}, finished_items={}, use_cfg=False, cfg_scale=1.0, use_kv_cache=True): + def local_transformer_sample_autoregressive( + self, + dec_output, + temperature=0.7, + topk=80, + unfinished_items={}, + finished_items={}, + use_cfg=False, + cfg_scale=1.0, + use_kv_cache=True, + ): # dec_output: (B, E) self.local_transformer.reset_cache(use_cache=use_kv_cache) - dec_output = dec_output.unsqueeze(1) # (B, 1, E) - local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) + dec_output = dec_output.unsqueeze(1) # (B, 1, E) + local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) all_preds = [] for codebook_num in range(self.num_audio_codebooks * self.frame_stacking_factor): - _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device) - local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128) - codebook_logits = self.local_transformer_out_projections[codebook_num](local_transformer_output[:, -1, :]) # (B, num_all_tokens_per_codebook) + _mask = torch.ones( + local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device + ) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128) + codebook_logits = self.local_transformer_out_projections[codebook_num]( + local_transformer_output[:, -1, :] + ) # (B, num_all_tokens_per_codebook) if use_cfg: actual_batch_size = codebook_logits.size(0) // 2 conditional_logits = codebook_logits[:actual_batch_size] unconditional_logits = codebook_logits[actual_batch_size:] - cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits + cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits codebook_logits[:actual_batch_size] = cfg_logits for item_idx in unfinished_items: @@ -866,27 +949,41 @@ def local_transformer_sample_autoregressive(self, dec_output, temperature=0.7, t codebook_logits[item_idx, self.audio_eos_id] = 0.0 
codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1) - codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) - indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(-1) # (B, num_tokens_per_codebook) + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) + indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( + -1 + ) # (B, num_tokens_per_codebook) codebook_logits_rescored = codebook_logits.clone() codebook_logits_rescored[indices_to_remove] = float('-inf') - codebook_probs = torch.softmax(codebook_logits_rescored / temperature, dim=-1) # (B, num_tokens_per_codebook) - codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) + codebook_probs = torch.softmax( + codebook_logits_rescored / temperature, dim=-1 + ) # (B, num_tokens_per_codebook) + codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) if use_cfg: codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size] all_preds.append(codebook_preds) - next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze(1) # (B, 1, 128) - next_local_transformer_input = self.local_transformer_in_projection(next_local_transformer_input) # (B, 1, 128) - local_transformer_input = torch.cat([local_transformer_input, next_local_transformer_input], dim=1) # (B, T+1, 128) - - all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks * frame_stacking_factor) - all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute(0,2,1) # (B, num_codebooks, frame_stacking_factor) + next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze( + 1 + ) # (B, 1, 128) + next_local_transformer_input = self.local_transformer_in_projection( + next_local_transformer_input + ) # (B, 1, 128) + local_transformer_input = torch.cat( + [local_transformer_input, next_local_transformer_input], dim=1 + ) # (B, T+1, 128) + + all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks * frame_stacking_factor) + all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute( + 0, 2, 1 + ) # (B, num_codebooks, frame_stacking_factor) if use_cfg: all_preds = all_preds[:actual_batch_size] return all_preds - def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={}): + def sample_codes_from_logits( + self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={} + ): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep all_preds = [[] for _ in range(self.frame_stacking_factor)] for fs_index in range(self.frame_stacking_factor): @@ -908,11 +1005,15 @@ def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80, codebook_logits_rescored = codebook_logits.clone() codebook_logits_rescored[indices_to_remove] = float('-inf') - codebook_probs = torch.softmax(codebook_logits_rescored / temperature, dim=-1) # (B, num_tokens_per_codebook) + codebook_probs = torch.softmax( + codebook_logits_rescored / temperature, dim=-1 + ) # (B, num_tokens_per_codebook) codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) all_preds[fs_index].append(codebook_preds) - - all_preds = [torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds] # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks) + + all_preds = [ + 
torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds + ] # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks) all_preds = torch.stack(all_preds, dim=2) # (B, num_codebooks, frame_stacking_factor) return all_preds @@ -928,7 +1029,9 @@ def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens is_wandb = isinstance(logger, WandbLogger) is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: - raise ValueError(f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported.") + raise ValueError( + f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." + ) wandb_images_log[f"Image/{prefix}/attention_matrix"] = list() for idx in range(min(3, attention_prob_matrix_mean.size(0))): @@ -939,7 +1042,9 @@ def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens img_np = plot_alignment_to_numpy(item_attn_matrix.T) if is_wandb: - wandb_images_log[f"Image/{prefix}/attention_matrix"].append(wandb.Image(img_np, caption=f"Example_{idx}")) + wandb_images_log[f"Image/{prefix}/attention_matrix"].append( + wandb.Image(img_np, caption=f"Example_{idx}") + ) if is_tb: logger.experiment.add_image( @@ -974,7 +1079,9 @@ def log_val_audio_example( is_wandb = isinstance(logger, WandbLogger) is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: - raise ValueError(f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported.") + raise ValueError( + f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." + ) for idx in range(min(3, pred_audio.size(0))): pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() @@ -989,9 +1096,15 @@ def log_val_audio_example( if is_wandb: wandb_audio_log[f"Audio/Example_{idx}"] = list() if context_audio_np is not None: - wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context")) - wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction")) - wandb_audio_log[f"Audio/Example_{idx}"].append(wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target")) + wandb_audio_log[f"Audio/Example_{idx}"].append( + wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context") + ) + wandb_audio_log[f"Audio/Example_{idx}"].append( + wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction") + ) + wandb_audio_log[f"Audio/Example_{idx}"].append( + wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target") + ) if is_tb: if context_audio_np is not None: @@ -1061,7 +1174,7 @@ def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_co ) return alignment_loss - def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int =0): + def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int = 0): """ Pads the time dimension of the audio codes to a multiple of the frame stacking factor. 
Args: @@ -1074,17 +1187,25 @@ def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int T = audio_codes.size(2) T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor) if T_padded > T: - padding = pad_token * torch.ones(audio_codes.size(0), audio_codes.size(1), T_padded - T, device=audio_codes.device, dtype=audio_codes.dtype) + padding = pad_token * torch.ones( + audio_codes.size(0), + audio_codes.size(1), + T_padded - T, + device=audio_codes.device, + dtype=audio_codes.dtype, + ) audio_codes = torch.cat([audio_codes, padding], dim=2) return audio_codes def embed_context_text(self, context_text_tokens): if self.legacy_text_conditioning: - context_text_tokens = context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name] + context_text_tokens = ( + context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name] + ) context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E) else: context_text_embedded = self.text_embedding(context_text_tokens) # (B, L, E) - + return context_text_embedded def prepare_context_tensors(self, batch): @@ -1136,7 +1257,7 @@ def prepare_context_tensors(self, batch): context_text_tokens = batch['context_text_tokens'] context_text_lens = batch['context_text_tokens_lens'] context_text_embedded = self.embed_context_text(context_text_tokens) # (B, L, E) - + # Pad context_audio_embedded or context_text_embedded so that they have same number of timesteps if context_audio_embedded.size(1) < context_text_embedded.size(1): padding = torch.zeros( @@ -1165,7 +1286,9 @@ def prepare_context_tensors(self, batch): else: context_input_embedded = context_audio_embedded context_input_lens = context_audio_codes_lens - context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to(context_input_lens.dtype) + context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to( + context_input_lens.dtype + ) context_mask = get_mask_from_lengths(context_input_lens) @@ -1207,10 +1330,16 @@ def prepare_context_tensors(self, batch): # Convert prior to a list of tensors, one for each layer # Set None for layers not in ctc_prior_layer_ids if self.model_type == 'multi_encoder_context_tts': - text_attn_prior = [attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None for layer_idx in range(self.decoder.n_layers) ] + text_attn_prior = [ + attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None + for layer_idx in range(self.decoder.n_layers) + ] attn_prior = [text_attn_prior, attn_prior[1]] else: - attn_prior = [attn_prior if layer_idx in self.ctc_prior_layer_ids else None for layer_idx in range(self.decoder.n_layers) ] + attn_prior = [ + attn_prior if layer_idx in self.ctc_prior_layer_ids else None + for layer_idx in range(self.decoder.n_layers) + ] return { 'beta_binomial_attn_prior': batch.get('align_prior_matrix', None), @@ -1245,12 +1374,12 @@ def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_ha prior_updated = False for idx, prior in enumerate(text_attn_prior): if prior is not None: - text_attn_prior[idx][:,-aligner_attn_hard.size(1):,:] = aligner_attn_hard + text_attn_prior[idx][:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard prior_updated = True assert prior_updated, "Did not find any prior to update" else: # Same prior for all layers - text_attn_prior[:,-aligner_attn_hard.size(1):,:] = aligner_attn_hard + text_attn_prior[:, -aligner_attn_hard.size(1) :, :] = 
aligner_attn_hard if self.model_type == 'multi_encoder_context_tts': attn_prior[0] = text_attn_prior @@ -1264,25 +1393,39 @@ def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens): if self.binarize_attn_method == 'nemo_binarize': logging.debug("Binarizing attention using nemo_binarize") binarize_repeat_audio_factor = self.binarize_repeat_audio_factor - aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave(binarize_repeat_audio_factor, dim=2) # B, 1, 2*audio_timesteps, text_timesteps - aligner_attn_hard = binarize_attention_parallel(aligner_attn_soft_repeated, text_lens, audio_lens*binarize_repeat_audio_factor).squeeze(1) # B, 2*audio_timesteps, text_timesteps - aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps + aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave( + binarize_repeat_audio_factor, dim=2 + ) # B, 1, 2*audio_timesteps, text_timesteps + aligner_attn_hard = binarize_attention_parallel( + aligner_attn_soft_repeated, text_lens, audio_lens * binarize_repeat_audio_factor + ).squeeze( + 1 + ) # B, 2*audio_timesteps, text_timesteps + aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps elif self.binarize_attn_method == 'argmax': logging.debug("Binarizing attention using argmax") aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1) - aligner_attn_hard = torch.nn.functional.one_hot(aligner_attn_hard, num_classes=aligner_attn_soft.size(-1)).float() + aligner_attn_hard = torch.nn.functional.one_hot( + aligner_attn_hard, num_classes=aligner_attn_soft.size(-1) + ).float() else: - raise ValueError(f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'.") + raise ValueError( + f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'." + ) aligner_attn_hard_wider = aligner_attn_hard + self.binarized_prior_epsilon for future_timestep in range(self.prior_future_context): decay_factor = self.prior_future_decay ** (future_timestep + 1) - aligner_attn_hard_wider[:,:,future_timestep+1:] += decay_factor * aligner_attn_hard[:,:,:-(future_timestep+1)] + aligner_attn_hard_wider[:, :, future_timestep + 1 :] += ( + decay_factor * aligner_attn_hard[:, :, : -(future_timestep + 1)] + ) for past_timestep in range(self.prior_past_context): decay_factor = self.prior_past_decay ** (past_timestep + 1) - aligner_attn_hard_wider[:,:,:-past_timestep-1] += decay_factor * aligner_attn_hard[:,:,past_timestep+1:] + aligner_attn_hard_wider[:, :, : -past_timestep - 1] += ( + decay_factor * aligner_attn_hard[:, :, past_timestep + 1 :] + ) aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0) return aligner_attn_hard_wider @@ -1329,28 +1472,39 @@ def process_batch(self, batch, mode="train"): # repeat the BOS token to frame_stacking_factor times. This is necessary since at inference # we need to start autoregressive generation from a full stack indicating BOS. 
# TODO: @rfejgin: this assert might be slow due to GPU/CPU sync - assert (audio_codes[:,:,0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token" - audio_codes = torch.cat([torch.full((audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1), self.audio_bos_id, device=audio_codes.device, dtype=audio_codes.dtype), audio_codes], dim=2) - audio_codes_lens += self.frame_stacking_factor - 1 # account for BOS repeat + assert (audio_codes[:, :, 0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token" + audio_codes = torch.cat( + [ + torch.full( + (audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1), + self.audio_bos_id, + device=audio_codes.device, + dtype=audio_codes.dtype, + ), + audio_codes, + ], + dim=2, + ) + audio_codes_lens += self.frame_stacking_factor - 1 # account for BOS repeat audio_codes = self.pad_audio_codes(audio_codes, self.frame_stacking_factor, pad_token=0) # Note: if a tensor lacks the `_unstacked` suffix, it can be assumed to be in the frame-stacked domain - + # drop last (stacked) frame since it is not part of *input* - audio_codes_input_unstacked = audio_codes[:, :, :-self.frame_stacking_factor] # B, C, T' + audio_codes_input_unstacked = audio_codes[:, :, : -self.frame_stacking_factor] # B, C, T' # drop first (stacked) frame which contains BOS token(s) which are not part of *target* - audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor:] - audio_codes_lens_input_unstacked = audio_codes_lens - 1 # don't count EOS for input - audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor # don't count BOS for target + audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor :] + audio_codes_lens_input_unstacked = audio_codes_lens - 1 # don't count EOS for input + audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor # don't count BOS for target audio_codes_lens_input = torch.floor(audio_codes_lens_input_unstacked / self.frame_stacking_factor).long() - audio_codes_embedded_all = self.embed_audio_tokens(audio_codes) # (B, T, E) # Computing this to be use in the alignment encoder - audio_codes_embedded = audio_codes_embedded_all[:, :-1, :] # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`) + audio_codes_embedded_all = self.embed_audio_tokens( + audio_codes + ) # (B, T, E) # Computing this to be used in the alignment encoder + audio_codes_embedded = audio_codes_embedded_all[ + :, :-1, : + ] # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`) audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input) - use_cfg = ( - (self.cfg_unconditional_prob > 0.0) - and (mode == "train") - and (context_tensors['cond'] is not None) - ) + use_cfg = (self.cfg_unconditional_prob > 0.0) and (mode == "train") and (context_tensors['cond'] is not None) if use_cfg and torch.rand(1).item() < self.cfg_unconditional_prob: cond, cond_mask, additional_decoder_input, additional_decoder_mask, attn_prior = ( self.prepare_dummy_cond_for_cfg( @@ -1368,11 +1522,7 @@ def process_batch(self, batch, mode="train"): additional_decoder_mask = context_tensors['additional_decoder_mask'] attn_prior = context_tensors['attn_prior'] - if ( - mode == "train" - and self.decoder_input_dropout_prob > 0.0 - and torch.rand(1).item() < 0.5 - ): + if mode == "train" and self.decoder_input_dropout_prob > 0.0 and 
torch.rand(1).item() < 0.5: # For some batches (half of them), replace decoder_input_dropout_prob of the timesteps with random tokens max_codebook_val = self.dec_random_input_max # @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens @@ -1387,8 +1537,10 @@ def process_batch(self, batch, mode="train"): > self.decoder_input_dropout_prob ) # timestep_mask is True for timesteps to be kept - audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * (~dec_dropout_mask) - audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked) # (B, T', E) + audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * ( + ~dec_dropout_mask + ) + audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked) # (B, T', E) if context_tensors['additional_decoder_input'] is not None: dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1) @@ -1407,23 +1559,25 @@ def process_batch(self, batch, mode="train"): # Passing target audio embeddings to the alignment encoder if self.global_step < self.aligner_encoder_train_steps: aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( - queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' - keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T + queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' + keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T mask=~context_tensors['text_mask'].unsqueeze(-1), - attn_prior=aligner_prior + attn_prior=aligner_prior, ) aligner_encoder_loss = self.alignment_encoder_loss( - attn_logprob=aligner_attn_logprobs, in_lens=context_tensors['text_lens'], out_lens=audio_codes_lens_input + attn_logprob=aligner_attn_logprobs, + in_lens=context_tensors['text_lens'], + out_lens=audio_codes_lens_input, ) else: with torch.no_grad(): # Just get the attention matrix without computing the loss or gradients aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( - queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' - keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T + queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' + keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T mask=~context_tensors['text_mask'].unsqueeze(-1), - attn_prior=aligner_prior + attn_prior=aligner_prior, ) with torch.no_grad(): @@ -1448,15 +1602,20 @@ def process_batch(self, batch, mode="train"): # Codebook loss (parallel) codebook_loss, loss_mask = self.compute_loss( - logits, audio_codes_target_unstacked, + logits, + audio_codes_target_unstacked, audio_codes_lens_target_unstacked, - frame_stacking_factor=self.frame_stacking_factor + frame_stacking_factor=self.frame_stacking_factor, ) # Alignment loss alignment_loss = None if self.alignment_loss_scale > 0.0 and not disable_alignment_loss: text_lens = context_tensors['text_lens'] - cross_attention_scores = [attn['cross_attn_probabilities'][1] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids] + cross_attention_scores = [ + attn['cross_attn_probabilities'][1] + for layer_idx, attn in enumerate(attn_info) + if layer_idx in self.ctc_prior_layer_ids + ] alignment_loss = self.compute_alignment_loss( cross_attention_scores, text_lens, audio_codes_lens_input, dec_context_size ) @@ -1474,13 +1633,29 @@ def process_batch(self, batch, 
mode="train"): audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked) # TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part of # of a pair where the first position is valid. Is this an issue? - local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_masked, targets_offset_by_one=True) - local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target_unstacked, audio_codes_lens_target_unstacked, mask_tokens_mask, frame_stacking_factor=self.frame_stacking_factor) + local_transformer_logits = self.compute_local_transformer_logits( + dec_out[:, dec_context_size:, :], audio_codes_masked, targets_offset_by_one=True + ) + local_transformer_loss, _ = self.compute_loss( + local_transformer_logits, + audio_codes_target_unstacked, + audio_codes_lens_target_unstacked, + mask_tokens_mask, + frame_stacking_factor=self.frame_stacking_factor, + ) else: ## Autoregressive ## assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type" - local_transformer_logits = self.compute_local_transformer_logits(dec_out[:,dec_context_size:,:], audio_codes_target_unstacked, targets_offset_by_one=False) - local_transformer_loss, _ = self.compute_loss(local_transformer_logits, audio_codes_target_unstacked, audio_codes_lens_target_unstacked, None, frame_stacking_factor=self.frame_stacking_factor) + local_transformer_logits = self.compute_local_transformer_logits( + dec_out[:, dec_context_size:, :], audio_codes_target_unstacked, targets_offset_by_one=False + ) + local_transformer_loss, _ = self.compute_loss( + local_transformer_logits, + audio_codes_target_unstacked, + audio_codes_lens_target_unstacked, + None, + frame_stacking_factor=self.frame_stacking_factor, + ) loss = loss + self.local_transformer_loss_scale * local_transformer_loss if aligner_encoder_loss is not None: @@ -1491,8 +1666,8 @@ def process_batch(self, batch, mode="train"): 'attn_info': attn_info, 'loss': loss, 'codebook_loss': codebook_loss, - 'local_transformer_loss' : local_transformer_loss, - 'local_transformer_logits' : local_transformer_logits, + 'local_transformer_loss': local_transformer_loss, + 'local_transformer_logits': local_transformer_logits, 'loss_mask': loss_mask, 'alignment_loss': alignment_loss, 'aligner_encoder_loss': aligner_encoder_loss, @@ -1530,25 +1705,32 @@ def training_step(self, batch, batch_idx): "train/batch_size": batch_size, "train/text_token_max_len": text_token_max_len, "train/text_token_total_num_in_batch": text_token_total_num.item(), - "train/text_token_pad_ratio_percent_in_batch": 100 * (1 - text_token_total_num / (batch_size * text_token_max_len)), + "train/text_token_pad_ratio_percent_in_batch": 100 + * (1 - text_token_total_num / (batch_size * text_token_max_len)), } if "audio_codes" in batch: audio_codes_max_len = batch["audio_codes"].shape[-1] audio_codes_total_num = batch["audio_codes_lens"].sum() - batch_info_dict.update({ - "train/audio_codes_max_len": audio_codes_max_len, - "train/audio_codes_total_num_in_batch": audio_codes_total_num.item(), - "train/audio_codes_pad_ratio_percent_in_batch": 100 * (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)), - }) + batch_info_dict.update( + { + "train/audio_codes_max_len": audio_codes_max_len, + "train/audio_codes_total_num_in_batch": audio_codes_total_num.item(), + "train/audio_codes_pad_ratio_percent_in_batch": 100 + * (1 - 
audio_codes_total_num / (batch_size * audio_codes_max_len)), + } + ) else: audio_samples_max_len = batch["audio"].shape[-1] audio_samples_total_num = batch["audio_lens"].sum() - batch_info_dict.update({ - "train/audio_samples_max_len": audio_samples_max_len, - "train/audio_samples_total_num_in_batch": audio_samples_total_num.item(), - "train/audio_samples_pad_ratio_percent_in_batch": 100 * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)), - }) + batch_info_dict.update( + { + "train/audio_samples_max_len": audio_samples_max_len, + "train/audio_samples_total_num_in_batch": audio_samples_total_num.item(), + "train/audio_samples_pad_ratio_percent_in_batch": 100 + * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)), + } + ) self.log_dict(batch_info_dict, on_step=True) @@ -1592,7 +1774,11 @@ def validation_step(self, batch, batch_idx): and len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1 ): # cross_attn_probabilities only returned when not using flash attention - cross_attention_probs = [attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids] + cross_attention_probs = [ + attn['cross_attn_probabilities'][0] + for layer_idx, attn in enumerate(attn_info) + if layer_idx in self.ctc_prior_layer_ids + ] wandb_log_dict.update( self.log_attention_probs( cross_attention_probs, @@ -1604,14 +1790,14 @@ def validation_step(self, batch, batch_idx): ) for layer_idx in self.transcript_decoder_layers: - cross_attention_probs = [ attn_info[layer_idx]['cross_attn_probabilities'][0] ] + cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]] wandb_log_dict.update( self.log_attention_probs( cross_attention_probs, audio_codes_lens_target, text_lens, prefix=f"val/layer_{layer_idx}", - dec_context_size=dec_context_size + dec_context_size=dec_context_size, ) ) @@ -1659,20 +1845,31 @@ def get_cross_attention_scores(self, attn_probs, filter_layers=None): mean_cross_attn_scores = [] all_heads_cross_attn_scores = [] for lidx, layerwise_attn_prob in enumerate(attn_probs): - if (filter_layers is not None and lidx not in filter_layers) or (lidx not in self.transcript_decoder_layers): + if (filter_layers is not None and lidx not in filter_layers) or ( + lidx not in self.transcript_decoder_layers + ): continue - cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][0] # B, H, audio_timesteps, text_timesteps - mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1)) # B, audio_timesteps, text_timesteps + cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][ + 0 + ] # B, H, audio_timesteps, text_timesteps + mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1)) # B, audio_timesteps, text_timesteps for head_idx in range(cross_attn_prob.size(1)): - all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :]) # B, text_timesteps + all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :]) # B, text_timesteps - mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1) # B, L, audio_timesteps, text_timesteps - mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1) # B, audio_timesteps, text_timesteps - last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps + mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1) # B, L, audio_timesteps, text_timesteps + mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1) # B, audio_timesteps, text_timesteps + 
last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps return last_audio_timestep_scores, all_heads_cross_attn_scores - def get_most_attended_text_timestep(self, alignment_attention_scores, last_attended_timesteps, - text_lens, lookahead_window_size, attended_timestep_counter, batch_size): + def get_most_attended_text_timestep( + self, + alignment_attention_scores, + last_attended_timesteps, + text_lens, + lookahead_window_size, + attended_timestep_counter, + batch_size, + ): """ Returns the most attended timestep for each batch item """ @@ -1683,20 +1880,32 @@ def get_most_attended_text_timestep(self, alignment_attention_scores, last_atten # This is probably an attention sink! Move to the next timestep last_attended_timestep += 1 window_size = lookahead_window_size - window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps - item_attention_scores = alignment_attention_scores[bidx,last_attended_timestep:window_end] + window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps + item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end] if item_attention_scores.size(0) == 0: # This means the sentence has ended attended_timestep = text_lens[bidx] - 1 else: attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep text_time_step_attended.append(attended_timestep) - attended_timestep_counter[bidx][attended_timestep] = attended_timestep_counter[bidx].get(attended_timestep, 0) + 1 + attended_timestep_counter[bidx][attended_timestep] = ( + attended_timestep_counter[bidx].get(attended_timestep, 0) + 1 + ) return text_time_step_attended, attended_timestep_counter - def construct_inference_prior(self, prior_epsilon, cross_attention_scores, - text_lens, text_time_step_attended, attended_timestep_counter, - unfinished_texts, finished_texts_counter, end_indices, lookahead_window_size, batch_size): + def construct_inference_prior( + self, + prior_epsilon, + cross_attention_scores, + text_lens, + text_time_step_attended, + attended_timestep_counter, + unfinished_texts, + finished_texts_counter, + end_indices, + lookahead_window_size, + batch_size, + ): # Attn prior for the next timestep _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon _attn_prior = _attn_prior.to(cross_attention_scores.device) @@ -1707,16 +1916,20 @@ def construct_inference_prior(self, prior_epsilon, cross_attention_scores, # Very short sentences, No Prior _attn_prior[bidx, 0, :] = 1.0 else: - _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx]-1)] = 1.0 # Slight exposure to history for better pronounciation. Not very important. - _attn_prior[bidx, 0, text_time_step_attended[bidx]] = 1.0 # Slightly bias to continue moving forward. Not very important. + _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx] - 1)] = ( + 1.0 # Slight exposure to history for better pronunciation. Not very important. + ) + _attn_prior[bidx, 0, text_time_step_attended[bidx]] = ( + 1.0 # Slightly bias to continue moving forward. Not very important. 
+ ) for ind in range(1, lookahead_window_size + 1): - _attn_prior[bidx, 0, min(text_time_step_attended[bidx]+ind, _text_len - 1) ] = 1.0 + _attn_prior[bidx, 0, min(text_time_step_attended[bidx] + ind, _text_len - 1)] = 1.0 # Penalize timesteps that have been attended to more than 10 times for _timestep in attended_timestep_counter[bidx]: if attended_timestep_counter[bidx][_timestep] >= 10: # This means the timestep has been attended to more than 10 times (To avoid getting stuck) - _attn_prior[bidx, 0, :_timestep+1] = prior_epsilon + _attn_prior[bidx, 0, : _timestep + 1] = prior_epsilon unfinished_texts[bidx] = False if text_time_step_attended[bidx] < text_lens[bidx] - 3: @@ -1737,24 +1950,42 @@ def construct_inference_prior(self, prior_epsilon, cross_attention_scores, return _attn_prior, unfinished_texts, finished_texts_counter - def get_inference_attention_plots(self, cross_attention_scores_all_timesteps, all_heads_cross_attn_scores_all_timesteps, text_lens, predicted_codes_lens, batch_size, compute_all_heads_attn_maps): - cross_attention_scores_all_timesteps = torch.stack(cross_attention_scores_all_timesteps, dim=2) # B, text_timesteps, T' + def get_inference_attention_plots( + self, + cross_attention_scores_all_timesteps, + all_heads_cross_attn_scores_all_timesteps, + text_lens, + predicted_codes_lens, + batch_size, + compute_all_heads_attn_maps, + ): + cross_attention_scores_all_timesteps = torch.stack( + cross_attention_scores_all_timesteps, dim=2 + ) # B, text_timesteps, T' headwise_cross_attention_scores_all_timesteps = [] for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): - head_cross_attention_all_timesteps = torch.stack([x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2) # B, text_timesteps, T' + head_cross_attention_all_timesteps = torch.stack( + [x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2 + ) # B, text_timesteps, T' headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps) cross_attention_maps = [] headwise_cross_attention_maps = [] for bidx in range(batch_size): - item_cross_attention_scores = cross_attention_scores_all_timesteps[bidx,:text_lens[bidx],:predicted_codes_lens[bidx]] + item_cross_attention_scores = cross_attention_scores_all_timesteps[ + bidx, : text_lens[bidx], : predicted_codes_lens[bidx] + ] cross_attn_np = plot_alignment_to_numpy(item_cross_attention_scores.cpu().numpy()) cross_attention_maps.append(cross_attn_np) item_all_head_cross_attn_maps = [] if compute_all_heads_attn_maps: for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): - item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][bidx,:text_lens[bidx],:predicted_codes_lens[bidx]] - headwise_cross_attn_np = plot_alignment_to_numpy(item_headwise_cross_attention_scores.cpu().numpy()) + item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][ + bidx, : text_lens[bidx], : predicted_codes_lens[bidx] + ] + headwise_cross_attn_np = plot_alignment_to_numpy( + item_headwise_cross_attention_scores.cpu().numpy() + ) item_all_head_cross_attn_maps.append(headwise_cross_attn_np) headwise_cross_attention_maps.append(item_all_head_cross_attn_maps) @@ -1769,7 +2000,7 @@ def find_eos_frame_index(self, codes) -> Optional[int]: Returns: index (within the frame stack) of the first frame with EOS, or `None` if no EOS is found """ - eos_mask = (codes == self.audio_eos_id) # (codebooks, frame_stacking_factor) + eos_mask = codes == self.audio_eos_id # 
(codebooks, frame_stacking_factor) eos_per_frame = eos_mask.any(dim=0) # (frame_stacking_factor,) - True if any codebook has EOS in this frame # find first frame with EOS if eos_per_frame.any(): @@ -1778,28 +2009,29 @@ return None def infer_batch( - self, - batch, - max_decoder_steps=500, - temperature=0.7, - topk=80, - use_cfg=False, - cfg_scale=1.0, - return_cross_attn_probs=False, - apply_attention_prior=False, - prior_epsilon=1e-5, - lookahead_window_size=10, - estimate_alignment_from_layers=None, - apply_prior_to_layers=None, - start_prior_after_n_audio_steps=10, - compute_all_heads_attn_maps=False, - use_local_transformer_for_inference=False, - use_LT_kv_cache=True, - maskgit_n_steps=3, - maskgit_noise_scale=0.0, - maskgit_fixed_schedule=None, - maskgit_dynamic_cfg_scale=False, - maskgit_sampling_type=None): + self, + batch, + max_decoder_steps=500, + temperature=0.7, + topk=80, + use_cfg=False, + cfg_scale=1.0, + return_cross_attn_probs=False, + apply_attention_prior=False, + prior_epsilon=1e-5, + lookahead_window_size=10, + estimate_alignment_from_layers=None, + apply_prior_to_layers=None, + start_prior_after_n_audio_steps=10, + compute_all_heads_attn_maps=False, + use_local_transformer_for_inference=False, + use_LT_kv_cache=True, + maskgit_n_steps=3, + maskgit_noise_scale=0.0, + maskgit_fixed_schedule=None, + maskgit_dynamic_cfg_scale=False, + maskgit_sampling_type=None, + ): with torch.no_grad(): start_time = time.time() self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference) @@ -1807,9 +2039,13 @@ def infer_batch( context_tensors = self.prepare_context_tensors(batch) text = context_tensors['text'] audio_codes_bos = torch.full( - (text.size(0), self.num_audio_codebooks, self.frame_stacking_factor), self.audio_bos_id, device=text.device + (text.size(0), self.num_audio_codebooks, self.frame_stacking_factor), + self.audio_bos_id, + device=text.device, ).long() - audio_codes_lens = torch.full((text.size(0),), 1, device=text.device).long() # intetionally 1 rather than self.frame_stacking_factor since this is in stacked form + audio_codes_lens = torch.full( + (text.size(0),), 1, device=text.device + ).long() # intentionally 1 rather than self.frame_stacking_factor since this is in stacked form audio_codes_input = audio_codes_bos audio_codes_mask = get_mask_from_lengths(audio_codes_lens) @@ -1832,7 +2068,9 @@ def infer_batch( unfinished_texts = {} finished_texts_counter = {} attended_timestep_counter = [{} for _ in range(text.size(0))] - last_attended_timesteps = [[1 for _ in range(text.size(0))]] # Maintain a list of attended timesteps as we predict audio for each batch item + last_attended_timesteps = [ + [1 for _ in range(text.size(0))] + ] # Maintain a list of attended timesteps as we predict audio for each batch item time_to_first_prediction = 0.0 for idx in range(max_decoder_steps // self.frame_stacking_factor): if idx == 1: @@ -1844,7 +2082,9 @@ def infer_batch( _audio_codes_embedded = torch.cat( [context_tensors['additional_decoder_input'], audio_codes_embedded], dim=1 ) - _audio_codes_mask = torch.cat([context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1) + _audio_codes_mask = torch.cat( + [context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1 + ) else: _audio_codes_embedded = audio_codes_embedded _audio_codes_mask = audio_codes_mask @@ -1859,7 +2099,6 @@ def infer_batch( if self.model_type == 'multi_encoder_context_tts': attn_prior = [attn_prior, None] - if use_cfg: batch_size = 
audio_codes_embedded.size(0) if isinstance(context_tensors['cond'], list): @@ -1896,7 +2135,7 @@ def infer_batch( cond=cfg_cond, cond_mask=cfg_cond_mask, attn_prior=attn_prior, - multi_encoder_mapping=context_tensors['multi_encoder_mapping'] + multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) cond_logits = combined_logits[:batch_size] @@ -1910,14 +2149,18 @@ def infer_batch( cond=context_tensors['cond'], cond_mask=context_tensors['cond_mask'], attn_prior=attn_prior, - multi_encoder_mapping=context_tensors['multi_encoder_mapping'] + multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) if return_cross_attn_probs or apply_attention_prior: - cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores(attn_probs) # B, text_timesteps + cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores( + attn_probs + ) # B, text_timesteps alignment_attention_scores = cross_attention_scores if estimate_alignment_from_layers is not None: - alignment_attention_scores, _ = self.get_cross_attention_scores(attn_probs, filter_layers=estimate_alignment_from_layers) # B, text_timesteps + alignment_attention_scores, _ = self.get_cross_attention_scores( + attn_probs, filter_layers=estimate_alignment_from_layers + ) # B, text_timesteps cross_attention_scores_all_timesteps.append(cross_attention_scores) all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores) @@ -1929,7 +2172,7 @@ def infer_batch( text_lens=context_tensors['text_lens'], lookahead_window_size=lookahead_window_size, attended_timestep_counter=attended_timestep_counter, - batch_size=batch_size + batch_size=batch_size, ) last_attended_timesteps.append(text_time_step_attended) _attn_prior, unfinished_texts, finished_texts_counter = self.construct_inference_prior( @@ -1942,18 +2185,20 @@ def infer_batch( finished_texts_counter=finished_texts_counter, end_indices=end_indices, lookahead_window_size=lookahead_window_size, - batch_size=batch_size + batch_size=batch_size, ) - finished_items = {k: v for k, v in finished_texts_counter.items() if v >= 20} # Items that have been close to the end for atleast 20 timesteps + finished_items = { + k: v for k, v in finished_texts_counter.items() if v >= 20 + } # Items that have been close to the end for atleast 20 timesteps unfinished_items = {k: v for k, v in unfinished_texts.items() if v} - all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) + all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) if use_local_transformer_for_inference: - if self.local_transformer_type == LocalTransformerType.AR : + if self.local_transformer_type == LocalTransformerType.AR: # Autoregressive sampling with local transformer audio_codes_next = self.local_transformer_sample_autoregressive( - dec_output=dec_out[:,-1,:], + dec_output=dec_out[:, -1, :], temperature=temperature, topk=topk, unfinished_items=unfinished_items, @@ -1964,7 +2209,7 @@ def infer_batch( ) elif self.local_transformer_type == LocalTransformerType.MASKGIT: audio_codes_next = self.local_transformer_sample_maskgit( - dec_output=dec_out[:,-1,:], + dec_output=dec_out[:, -1, :], temperature=temperature, topk=topk, unfinished_items=unfinished_items, @@ -1975,51 +2220,68 @@ def infer_batch( noise_scale=maskgit_noise_scale, fixed_schedule=maskgit_fixed_schedule, dynamic_cfg_scale=maskgit_dynamic_cfg_scale, - sampling_type=maskgit_sampling_type + sampling_type=maskgit_sampling_type, ) else: - raise 
ValueError(f"Local transformer inference requested by but local transformer type is {self.local_transformer_type}") + raise ValueError( + f"Local transformer inference requested by but local transformer type is {self.local_transformer_type}" + ) else: # Parallel sampling from all codebooks - audio_codes_next = self.sample_codes_from_logits(all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks, frame_stacking_factor) - all_codes_next_argmax = self.sample_codes_from_logits(all_code_logits_t, temperature=0.01, unfinished_items=unfinished_items, finished_items=finished_items) # (B, num_codebooks, frame_stacking_factor) + audio_codes_next = self.sample_codes_from_logits( + all_code_logits_t, + temperature=temperature, + topk=topk, + unfinished_items=unfinished_items, + finished_items=finished_items, + ) # (B, num_codebooks, frame_stacking_factor) + all_codes_next_argmax = self.sample_codes_from_logits( + all_code_logits_t, + temperature=0.01, + unfinished_items=unfinished_items, + finished_items=finished_items, + ) # (B, num_codebooks, frame_stacking_factor) for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: # check for EOS (including within the frame stack) eos_frame_multinomial = self.find_eos_frame_index(audio_codes_next[item_idx]) - eos_frame_argmax = self.find_eos_frame_index(all_codes_next_argmax[item_idx]) - eos_frame_multinomial = eos_frame_multinomial if eos_frame_multinomial is not None else float('inf') + eos_frame_argmax = self.find_eos_frame_index(all_codes_next_argmax[item_idx]) + eos_frame_multinomial = ( + eos_frame_multinomial if eos_frame_multinomial is not None else float('inf') + ) eos_frame_argmax = eos_frame_argmax if eos_frame_argmax is not None else float('inf') # pick minimum of the two frame_index = min(eos_frame_multinomial, eos_frame_argmax) if frame_index != float('inf'): - global_index = idx * self.frame_stacking_factor + frame_index + global_index = idx * self.frame_stacking_factor + frame_index end_indices[item_idx] = global_index print(f"End detected for item {item_idx} at decoder timestep: {idx}") all_predictions.append(audio_codes_next) - audio_codes_input = torch.cat( - [audio_codes_input, audio_codes_next], dim=-1 - ) # (B, C, T') - audio_codes_lens = audio_codes_lens + 1 # already in stacked form + audio_codes_input = torch.cat([audio_codes_input, audio_codes_next], dim=-1) # (B, C, T') + audio_codes_lens = audio_codes_lens + 1 # already in stacked form audio_codes_mask = get_mask_from_lengths(audio_codes_lens) if len(end_indices) == text.size(0) and len(all_predictions) >= 4: # Codec must be of atleast 4 timesteps to be decoded properly print("All ends reached") break tts_generation_time = time.time() - start_time - tts_generation_time_per_frame = tts_generation_time / (len(all_predictions)*self.frame_stacking_factor) + tts_generation_time_per_frame = tts_generation_time / (len(all_predictions) * self.frame_stacking_factor) # Concatenate the list of predictions along the time dimension. Note that when frame stacking is on, # this also undoes the stacking. 
predicted_codes = torch.cat(all_predictions, dim=-1) # (B, num_codebooks, T') - predicted_lens = [end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))] # Ensure that the codec is atleast of length 4 + predicted_lens = [ + end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0)) + ] # Ensure that the codec is atleast of length 4 predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long() predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens) end_time = time.time() - total_audio_duration_generated = (predicted_audio_lens.max().item() * predicted_audio_lens.shape[0])/self.sample_rate + total_audio_duration_generated = ( + predicted_audio_lens.max().item() * predicted_audio_lens.shape[0] + ) / self.sample_rate rtf = total_audio_duration_generated / (end_time - start_time) rtf_metrics = { 'rtf': rtf, @@ -2032,10 +2294,22 @@ def infer_batch( torch.cuda.empty_cache() if return_cross_attn_probs: cross_attention_maps, headwise_cross_attention_maps = self.get_inference_attention_plots( - cross_attention_scores_all_timesteps, all_heads_cross_attn_scores_all_timesteps, - context_tensors['text_lens'], predicted_codes_lens, text.size(0), compute_all_heads_attn_maps + cross_attention_scores_all_timesteps, + all_heads_cross_attn_scores_all_timesteps, + context_tensors['text_lens'], + predicted_codes_lens, + text.size(0), + compute_all_heads_attn_maps, + ) + return ( + predicted_audio, + predicted_audio_lens, + predicted_codes, + predicted_codes_lens, + rtf_metrics, + cross_attention_maps, + headwise_cross_attention_maps, ) - return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics, cross_attention_maps, headwise_cross_attention_maps else: # For backward compatibility return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics @@ -2060,7 +2334,9 @@ def test_step(self, batch, batch_idx): is_wandb = isinstance(logger, WandbLogger) is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: - raise ValueError(f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported.") + raise ValueError( + f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." + ) for idx in range(predicted_audio.size(0)): predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() @@ -2100,7 +2376,16 @@ def on_validation_epoch_end(self): # log val_loss in the same group as the other val metrics. self.log("val/loss", val_loss, prog_bar=True, sync_dist=True) # ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs. 
- self.log("val_loss", val_loss, prog_bar=False, sync_dist=True, on_step=False, on_epoch=True, logger=False, enable_graph=False) + self.log( + "val_loss", + val_loss, + prog_bar=False, + sync_dist=True, + on_step=False, + on_epoch=True, + logger=False, + enable_graph=False, + ) self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) @@ -2215,10 +2500,7 @@ def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: if dataset_cfg.dataloader_params.num_workers == 0: persistent_workers = False # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) - dataset.text_tokenizer = setup_tokenizers( - all_tokenizers_config=self.cfg.text_tokenizers, - mode='test' - ) + dataset.text_tokenizer = setup_tokenizers(all_tokenizers_config=self.cfg.text_tokenizers, mode='test') data_loader = torch.utils.data.DataLoader( dataset, diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index 3ed9107c4352..8c72389b6bd0 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -32,6 +32,7 @@ try: import torchaudio from torchaudio.pipelines import SQUIM_OBJECTIVE + HAVE_TORCHAUDIO = True except ImportError: HAVE_TORCHAUDIO = False @@ -48,14 +49,19 @@ class MagpieTTSModelOfflinePODataGen(MagpieTTSModel): def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) if cfg.get('pref_set_language', "en") == "en": - self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") + self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained( + model_name="nvidia/parakeet-ctc-0.6b" + ) self.eval_asr_model.freeze() - self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_large' + ) self.eval_speaker_verification_model.freeze() if cfg.get('load_whisper_model', False): from transformers import WhisperForConditionalGeneration, WhisperProcessor + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() @@ -73,14 +79,14 @@ def test_step(self, batch, batch_idx): temperature=temperature, topk=topk, use_cfg=use_cfg, - cfg_scale=cfg_scale + cfg_scale=cfg_scale, ) predicted_audio_paths = [] audio_durations = [] batch_invalid = False for idx in range(predicted_audio.size(0)): predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() - predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] item_idx = batch_idx * test_dl_batch_size + idx # Save the predicted audio log_dir = self.logger.log_dir @@ -92,31 +98,54 @@ def test_step(self, batch, batch_idx): sf.write(audio_path, predicted_audio_np, self.sample_rate) predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) - predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] - 
torch.save(predicted_codes_torch, os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt')) + predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]] + torch.save( + predicted_codes_torch, + os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'), + ) predicted_audio_paths.append(audio_path) if not batch_invalid: with torch.no_grad(): try: if self.cfg.get("pref_set_language", "en") == "en": - pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths)) - pred_transcripts = [ process_text_for_cer(transcript.text) for transcript in pred_transcripts ] + pred_transcripts = self.eval_asr_model.transcribe( + predicted_audio_paths, batch_size=len(predicted_audio_paths) + ) + pred_transcripts = [ + process_text_for_cer(transcript.text) for transcript in pred_transcripts + ] else: pred_transcripts = [] for audio_path in predicted_audio_paths: - transcript = transcribe_with_whisper(audio_path, self.cfg.pref_set_language, self.whisper_processor, self.whisper_model, self.device) + transcript = transcribe_with_whisper( + audio_path, + self.cfg.pref_set_language, + self.whisper_processor, + self.whisper_model, + self.device, + ) pred_transcripts.append(transcript) - pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] + pred_transcripts = [ + process_text_for_cer(transcript) for transcript in pred_transcripts + ] except Exception as e: - assert (predicted_audio_lens[idx] < 1000).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}" + assert ( + predicted_audio_lens[idx] < 1000 + ).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}" logging.warning(f"Exception during ASR transcription: {e}") - logging.warning(f"Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0") + logging.warning( + f"Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0" + ) batch_invalid = True - continue # don't break since we want to continue building audio durations list - pred_speaker_embeddings = get_speaker_embeddings_from_filepaths(predicted_audio_paths, self.eval_speaker_verification_model, self.device) - gt_speaker_embeddings = get_speaker_embeddings_from_filepaths(batch['audio_filepaths'], self.eval_speaker_verification_model, self.device) + continue # don't break since we want to continue building audio durations list + pred_speaker_embeddings = get_speaker_embeddings_from_filepaths( + predicted_audio_paths, self.eval_speaker_verification_model, self.device + ) + gt_speaker_embeddings = get_speaker_embeddings_from_filepaths( + batch['audio_filepaths'], self.eval_speaker_verification_model, self.device + ) for idx in range(predicted_audio.size(0)): if not batch_invalid: @@ -139,27 +168,31 @@ def test_step(self, batch, batch_idx): cer_gt = 1.0 wer_gt = 1.0 spk_similarity = 0.0 - pred_transcript = "" # do not change this string; subsequent processing relies on it + pred_transcript = "" # do not change this string; subsequent processing relies on it gt_transcript = process_text_for_cer(batch['raw_texts'][idx]) item_metrics = { 'cer_gt': float(cer_gt), 'wer_gt': float(wer_gt), - 'duration' : audio_durations[idx], + 'duration': audio_durations[idx], 'spk_similarity': float(spk_similarity), 'pred_transcript': 
pred_transcript, 'gt_transcript': gt_transcript, } - with open(os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w') as f: + with open( + os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w' + ) as f: json.dump(item_metrics, f) + class MagpieTTSModelOfflinePO(MagpieTTSModel): """ MagpieTTS_Model_OfflinePO is a class that extends MagpieTTS_Model to support offline preference optimization (DPO, IPO, RPO). Set cfg.model.dpo_loss_type to 'dpo', 'ipo', or 'rpo' to use the corresponding loss. """ + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) ref_model_cfg = copy.deepcopy(cfg) @@ -168,7 +201,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): ref_model_cfg.validation_ds = None self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) print("Loading reference model from checkpoint") - self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) + self._reference_model.load_state_dict( + torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict'] + ) self._reference_model.freeze() self._reference_model._no_state_dict = True print("Reference model loaded and frozen") @@ -199,17 +234,20 @@ def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): else: return (per_token_logps * loss_mask).sum(-1) - def preference_loss(self, policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - chosen_gt_rewards=None, - rejected_gt_rewards=None, - beta=0.2, - gt_reward_scale=1.0, - label_smoothing=0, - loss_type="dpo", - reference_free=False): + def preference_loss( + self, + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + chosen_gt_rewards=None, + rejected_gt_rewards=None, + beta=0.2, + gt_reward_scale=1.0, + label_smoothing=0, + loss_type="dpo", + reference_free=False, + ): """Compute the DPO loss for a batch of policy and reference model log probabilities. Referenced From: https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py Args: @@ -239,7 +277,7 @@ def preference_loss(self, policy_chosen_logps, # logits is the same as rewards_delta in NeMo aligner: https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 if loss_type == "ipo": - losses = (logits - 1/(2 * beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf + losses = (logits - 1 / (2 * beta)) ** 2 # Eq. 
17 of https://arxiv.org/pdf/2310.12036v2.pdf elif loss_type == "rpo": # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 logbeta_hat_chosen = torch.nn.functional.logsigmoid(beta * logits) @@ -247,17 +285,18 @@ def preference_loss(self, policy_chosen_logps, gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) - losses = ( - torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) - + torch.exp(logalpha_hat_rejected) * (logalpha_hat_rejected - logbeta_hat_rejected) - ) + losses = torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) + torch.exp( + logalpha_hat_rejected + ) * (logalpha_hat_rejected - logbeta_hat_rejected) elif loss_type == "rpo_sq": gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) losses = (beta * logits - gt_rewards_delta) ** 2 elif loss_type == "dpo": # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) F = torch.nn.functional - losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing + losses = ( + -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing + ) else: raise NotImplementedError("loss type {} is not implemented".format(loss_type)) @@ -289,14 +328,26 @@ def process_batch_dpo(self, batch_chosen_rejected): ref_codebook_logits_chosen = reference_model_output_chosen['logits'][:, :, si:ei] ref_codebook_logits_rejected = reference_model_output_rejected['logits'][:, :, si:ei] - codebook_labels_chosen = model_output_chosen['audio_codes_target'][:,codebook_idx] - codebook_labels_rejected = model_output_rejected['audio_codes_target'][:,codebook_idx] + codebook_labels_chosen = model_output_chosen['audio_codes_target'][:, codebook_idx] + codebook_labels_rejected = model_output_rejected['audio_codes_target'][:, codebook_idx] - codebook_log_probs_chosen = self._get_batch_logps(codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask'][:,codebook_idx]) - codebook_log_probs_rejected = self._get_batch_logps(codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask'][:,codebook_idx]) + codebook_log_probs_chosen = self._get_batch_logps( + codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask'][:, codebook_idx] + ) + codebook_log_probs_rejected = self._get_batch_logps( + codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask'][:, codebook_idx] + ) with torch.no_grad(): - ref_codebook_log_probs_chosen = self._get_batch_logps(ref_codebook_logits_chosen, codebook_labels_chosen, reference_model_output_chosen['loss_mask'][:,codebook_idx]) - ref_codebook_log_probs_rejected = self._get_batch_logps(ref_codebook_logits_rejected, codebook_labels_rejected, reference_model_output_rejected['loss_mask'][:,codebook_idx]) + ref_codebook_log_probs_chosen = self._get_batch_logps( + ref_codebook_logits_chosen, + codebook_labels_chosen, + reference_model_output_chosen['loss_mask'][:, codebook_idx], + ) + ref_codebook_log_probs_rejected = self._get_batch_logps( + ref_codebook_logits_rejected, + codebook_labels_rejected, + reference_model_output_rejected['loss_mask'][:, codebook_idx], + ) if 
chosen_policy_logprobs is None: chosen_policy_logprobs = codebook_log_probs_chosen @@ -359,12 +410,14 @@ def validation_step(self, batch, batch_idx): val_sft_loss = dpo_outputs['sft_loss'] val_alignment_loss = dpo_outputs['alignment_loss'] - self.validation_step_outputs.append({ - 'val_loss': val_loss, - 'val_pref_loss': val_pref_loss, - 'val_sft_loss': val_sft_loss, - 'val_alignment_loss': val_alignment_loss, - }) + self.validation_step_outputs.append( + { + 'val_loss': val_loss, + 'val_pref_loss': val_pref_loss, + 'val_sft_loss': val_sft_loss, + 'val_alignment_loss': val_alignment_loss, + } + ) def on_validation_epoch_end(self): def collect(key): @@ -388,11 +441,13 @@ def collect(key): self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() + class MagpieTTSModelOnlinePO(MagpieTTSModel): """ MagpieTTS_Model_OnlinePO is a class that extends MagpieTTS_Model to support online preference optimization (GRPO). """ + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): super().__init__(cfg, trainer) # Copy cfg @@ -401,31 +456,39 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): ref_model_cfg.train_ds = None ref_model_cfg.validation_ds = None - self.reference_free = self.cfg.get('reference_free', False) # True means we dont use the reference model + self.reference_free = self.cfg.get('reference_free', False) # True means we dont use the reference model if not self.reference_free: self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) print("Loading reference model from checkpoint") - self._reference_model.load_state_dict(torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']) + self._reference_model.load_state_dict( + torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict'] + ) self._reference_model.freeze() self._reference_model._no_state_dict = True print("Reference model loaded and frozen") if cfg.get('reward_asr_model', "nemo") == "nemo": - self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-ctc-0.6b") + self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained( + model_name="nvidia/parakeet-ctc-0.6b" + ) self.eval_asr_model.freeze() elif cfg.get('reward_asr_model', "nemo") == "whisper": from transformers import WhisperForConditionalGeneration, WhisperProcessor + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() else: raise ValueError(f"Unknown reward_asr_model: {cfg.reward_asr_model}") - self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_large' + ) self.eval_speaker_verification_model.freeze() if cfg.get('load_whisper_model', False): from transformers import WhisperForConditionalGeneration, WhisperProcessor + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() @@ -434,10 +497,12 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if use_pesq: assert HAVE_TORCHAUDIO, "torchaudio is required for PESQ reward" self.squim_objective_model = 
SQUIM_OBJECTIVE.get_model() - + self.loss_type = self.cfg.get('loss_type', 'grpo') if self.loss_type not in ['grpo', 'dr_grpo']: - raise ValueError(f"Received loss_type of {self.loss_type}, but the model only accepts one of ['grpo', 'dr_grpo']") + raise ValueError( + f"Received loss_type of {self.loss_type}, but the model only accepts one of ['grpo', 'dr_grpo']" + ) self.scale_rewards = self.cfg.get('scale_rewards', True) self.max_decoder_steps = self.cfg.get('max_decoder_steps', 430) # If the best record in the group is above this threshold, we will not use that group for training @@ -446,17 +511,22 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # If the worst record in the group exceeds this threshold, we will not use that group for training # Setting this to 1.0, because we clamp the ASR rewards to be in [0, 1] for OnlinePO self.worst_cer_threshold = self.cfg.get('worst_cer_threshold', 1.0) - def state_dict(self, destination=None, prefix='', keep_vars=False): state_dict = super().state_dict(destination, prefix, keep_vars) - keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model', - 'eval_asr_model', 'eval_speaker_verification_model', 'whisper_model'] + keys_substrings_to_exclude = [ + '_speaker_verification_model', + '_codec_model', + '_reference_model', + 'eval_asr_model', + 'eval_speaker_verification_model', + 'whisper_model', + ] for key in list(state_dict.keys()): if any([substring in key for substring in keys_substrings_to_exclude]): del state_dict[key] return state_dict - + def _get_per_token_logps(self, logits, labels, loss_mask): """Compute the log probabilities of the given labels under the given logits. @@ -482,7 +552,9 @@ def repeat_items_in_batch(self, batch, num_repeats): repeated_batch[key] = repeated_value return repeated_batch - def generate_and_reward(self, batch, num_generations_per_item, mode='train', use_local_transformer_for_inference=False): + def generate_and_reward( + self, batch, num_generations_per_item, mode='train', use_local_transformer_for_inference=False + ): batch_repeated = self.repeat_items_in_batch(batch, num_generations_per_item) temperature = self.cfg.get('inference_temperature', 0.7) topk = self.cfg.get('inference_topk', 80) @@ -494,7 +566,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use # Randomly set use_cfg based on the given probability use_cfg = random.random() < self.cfg.inference_cfg_prob cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch_repeated, max_decoder_steps=self.max_decoder_steps, @@ -503,13 +575,13 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use use_cfg=use_cfg, cfg_scale=cfg_scale, use_local_transformer_for_inference=use_local_transformer_for_inference, - use_LT_kv_cache=False, # We don't use KV caching for local transformer in GRPO due to issues. + use_LT_kv_cache=False, # We don't use KV caching for local transformer in GRPO due to issues. 
) predicted_audio_paths = [] audio_durations = [] for idx in range(predicted_audio.size(0)): predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() - predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] if predicted_audio_np.shape[0] < 1000: # Corner case to handle short audio files predicted_audio_np = np.pad(predicted_audio_np, (0, 1000 - predicted_audio_np.shape[0])) @@ -523,25 +595,35 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use sf.write(audio_path, predicted_audio_np, self.sample_rate) predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) - predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]] # C, T - torch.save(predicted_codes_torch, os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt')) + predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]] # C, T + torch.save( + predicted_codes_torch, + os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'), + ) predicted_audio_paths.append(audio_path) with torch.no_grad(): if self.cfg.get("reward_asr_model", "nemo") == "nemo": - pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths)) - pred_transcripts = [ process_text_for_cer(transcript.text) for transcript in pred_transcripts ] + pred_transcripts = self.eval_asr_model.transcribe( + predicted_audio_paths, batch_size=len(predicted_audio_paths) + ) + pred_transcripts = [process_text_for_cer(transcript.text) for transcript in pred_transcripts] elif self.cfg.get("reward_asr_model", "nemo") == "whisper": pred_transcripts = [] for item_idx, audio_path in enumerate(predicted_audio_paths): language = batch_repeated['languages'][item_idx] - transcript = transcribe_with_whisper(audio_path, language, self.whisper_processor, self.whisper_model, self.device) + transcript = transcribe_with_whisper( + audio_path, language, self.whisper_processor, self.whisper_model, self.device + ) pred_transcripts.append(transcript) pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] - pred_speaker_embeddings = get_speaker_embeddings_from_filepaths(predicted_audio_paths, self.eval_speaker_verification_model, self.device) - gt_speaker_embeddings = get_speaker_embeddings_from_filepaths(batch_repeated['audio_filepaths'], self.eval_speaker_verification_model, self.device) - + pred_speaker_embeddings = get_speaker_embeddings_from_filepaths( + predicted_audio_paths, self.eval_speaker_verification_model, self.device + ) + gt_speaker_embeddings = get_speaker_embeddings_from_filepaths( + batch_repeated['audio_filepaths'], self.eval_speaker_verification_model, self.device + ) batch_metrics = [] cer_reward_weight = self.cfg.get('cer_reward_weight', 0.5) @@ -572,23 +654,27 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use item_metrics = { 'cer_gt': float(cer_gt), 'wer_gt': float(wer_gt), - 'duration' : audio_durations[idx], + 'duration': audio_durations[idx], 'spk_similarity': float(spk_similarity), 'pred_transcript': pred_transcript, 'gt_transcript': gt_transcript, 'codes_len': predicted_codes_lens[idx].item(), - 'pesq' : pesq_hyp if use_pesq else 0.0, + 'pesq': pesq_hyp if use_pesq else 0.0, } - with open(os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w') as f: + with open( + os.path.join(audio_dir, 
f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w' + ) as f: json.dump(item_metrics, f) batch_metrics.append(item_metrics) num_groups = len(batch['audio_filepaths']) - best_ssim_achievable = self.cfg.get("best_ssim_achievable", 0.9) # Examples with this speaker similarity or higher will have SSIM reward of 1 - mean_cer_dataset = self.cfg.get("mean_cer_dataset", 0.1) # CER equal to this value will have reward of 0.5 - mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6) # SSIM equal to this value will have reward of 0.5 + best_ssim_achievable = self.cfg.get( + "best_ssim_achievable", 0.9 + ) # Examples with this speaker similarity or higher will have SSIM reward of 1 + mean_cer_dataset = self.cfg.get("mean_cer_dataset", 0.1) # CER equal to this value will have reward of 0.5 + mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6) # SSIM equal to this value will have reward of 0.5 all_groups_mean_reward = 0.0 all_groups_std_reward = 0.0 group_validities = [] @@ -606,18 +692,20 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use # Reward for best CER and best speaker similarity should be 1 item_cer = batch_metrics[idx]['cer_gt'] item_ssim = batch_metrics[idx]['spk_similarity'] - item_cer = min( max(item_cer, 0.0), 1.0) - item_ssim = max( min(item_ssim, best_ssim_achievable), 0.0) + item_cer = min(max(item_cer, 0.0), 1.0) + item_ssim = max(min(item_ssim, best_ssim_achievable), 0.0) item_pesq = batch_metrics[idx]['pesq'] group_best_cer = min(group_best_cer, item_cer) group_worst_cer = max(group_worst_cer, item_cer) - + if item_cer <= mean_cer_dataset: - cer_reward = 0.5 + 0.5 * (mean_cer_dataset - item_cer) / mean_cer_dataset # 0.5 to 1 + cer_reward = 0.5 + 0.5 * (mean_cer_dataset - item_cer) / mean_cer_dataset # 0.5 to 1 else: - cer_reward = 0.5 - 0.5 * (item_cer - mean_cer_dataset) / (1 - mean_cer_dataset) # 0 to 0.5 + cer_reward = 0.5 - 0.5 * (item_cer - mean_cer_dataset) / (1 - mean_cer_dataset) # 0 to 0.5 if item_ssim >= mean_ssim_dataset: - spk_similarity_reward = 0.5 + 0.5 * (item_ssim - mean_ssim_dataset) / (best_ssim_achievable - mean_ssim_dataset) + spk_similarity_reward = 0.5 + 0.5 * (item_ssim - mean_ssim_dataset) / ( + best_ssim_achievable - mean_ssim_dataset + ) else: spk_similarity_reward = 0.5 - 0.5 * (mean_ssim_dataset - item_ssim) / (mean_ssim_dataset) if use_pesq: @@ -625,26 +713,47 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use else: pesq_reward = 0.0 - batch_metrics[idx]['reward'] = cer_reward * cer_reward_weight + spk_similarity_reward * ssim_reward_weight + pesq_reward * pesq_reward_weight + batch_metrics[idx]['reward'] = ( + cer_reward * cer_reward_weight + + spk_similarity_reward * ssim_reward_weight + + pesq_reward * pesq_reward_weight + ) - if (batch_metrics[idx]['codes_len'] >= 425) or (batch_metrics[idx]['codes_len'] <= 3): # TODO: Remove hardcoded lengths + if (batch_metrics[idx]['codes_len'] >= 425) or ( + batch_metrics[idx]['codes_len'] <= 3 + ): # TODO: Remove hardcoded lengths # This means it did not complete the sentence or generated an extremely short sentence batch_metrics[idx]['reward'] = 0.0 - print("Item idx: ", idx, " CER: ", item_cer, " SSIM: ", item_ssim, " Reward: ", batch_metrics[idx]['reward'], " Codes len: ", batch_metrics[idx]['codes_len']) + print( + "Item idx: ", + idx, + " CER: ", + item_cer, + " SSIM: ", + item_ssim, + " Reward: ", + batch_metrics[idx]['reward'], + " Codes len: ", + batch_metrics[idx]['codes_len'], + ) 
batch_metrics[idx]['cer_reward'] = cer_reward batch_metrics[idx]['spk_similarity_reward'] = spk_similarity_reward batch_metrics[idx]['pesq_reward'] = pesq_reward mean_reward += batch_metrics[idx]['reward'] group_rewards.append(batch_metrics[idx]['reward']) - - if (group_best_cer > self.best_cer_threshold): + + if group_best_cer > self.best_cer_threshold: is_group_valid = False - print(f"Group {group_idx} has best CER {group_best_cer} which is above the threshold {self.best_cer_threshold}. Group is invalid.") - - if (group_worst_cer > self.worst_cer_threshold): + print( + f"Group {group_idx} has best CER {group_best_cer} which is above the threshold {self.best_cer_threshold}. Group is invalid." + ) + + if group_worst_cer > self.worst_cer_threshold: is_group_valid = False - print(f"Group {group_idx} has worst CER {group_worst_cer} which is above the threshold {self.worst_cer_threshold}. Group is invalid.") - + print( + f"Group {group_idx} has worst CER {group_worst_cer} which is above the threshold {self.worst_cer_threshold}. Group is invalid." + ) + for _ in range(num_generations_per_item): group_validities.append(is_group_valid) @@ -662,7 +771,7 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use advantages = [x['advantage'] for x in batch_metrics] advantages = torch.tensor(advantages, device=self.device) print("Mean reward: ", all_groups_mean_reward) - + group_validities = torch.tensor(group_validities, device=self.device) return { 'mean_reward': torch.tensor(all_groups_mean_reward, device=self.device), @@ -675,7 +784,6 @@ def generate_and_reward(self, batch, num_generations_per_item, mode='train', use 'group_validities': group_validities, } - def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): use_kv_cache_during_online_po = self.cfg.get("use_kv_cache_during_online_po", False) if use_kv_cache_during_online_po: @@ -688,14 +796,14 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): if use_local_transformer_prob > 0.0 and mode == 'train': use_local_transformer_for_inference = random.random() < use_local_transformer_prob logits_key = 'local_transformer_logits' - + with torch.no_grad(): self.eval() generated_codes_and_metrics = self.generate_and_reward( batch, n_generations_per_item, mode, - use_local_transformer_for_inference=use_local_transformer_for_inference + use_local_transformer_for_inference=use_local_transformer_for_inference, ) self.train() @@ -704,19 +812,29 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): self.decoder.reset_cache(use_cache=False) batch_repeated = generated_codes_and_metrics['batch_repeated'] - predicted_codes = generated_codes_and_metrics['predicted_codes'] # B, 8, T - predicted_codes_lens = generated_codes_and_metrics['predicted_codes_lens'] # B - predicted_codes = predicted_codes[:,:,:predicted_codes_lens.max()] + predicted_codes = generated_codes_and_metrics['predicted_codes'] # B, 8, T + predicted_codes_lens = generated_codes_and_metrics['predicted_codes_lens'] # B + predicted_codes = predicted_codes[:, :, : predicted_codes_lens.max()] - advantages = generated_codes_and_metrics['advantages'] # B + advantages = generated_codes_and_metrics['advantages'] # B # Add extra tokens for BOS and EOS - bos_tensor = torch.full((predicted_codes.size(0), predicted_codes.size(1), 1), self.audio_bos_id, dtype=predicted_codes.dtype, device=predicted_codes.device) - padding_tensor = torch.full((predicted_codes.size(0), predicted_codes.size(1), 1), 0, 
dtype=predicted_codes.dtype, device=predicted_codes.device) + bos_tensor = torch.full( + (predicted_codes.size(0), predicted_codes.size(1), 1), + self.audio_bos_id, + dtype=predicted_codes.dtype, + device=predicted_codes.device, + ) + padding_tensor = torch.full( + (predicted_codes.size(0), predicted_codes.size(1), 1), + 0, + dtype=predicted_codes.dtype, + device=predicted_codes.device, + ) predicted_codes = torch.cat([bos_tensor, predicted_codes, padding_tensor], dim=2) for idx in range(predicted_codes.size(0)): - predicted_codes[idx, :, predicted_codes_lens[idx]+1] = self.audio_eos_id # Accounts for BOS + predicted_codes[idx, :, predicted_codes_lens[idx] + 1] = self.audio_eos_id # Accounts for BOS batch_repeated['audio_codes'] = predicted_codes - batch_repeated['audio_codes_lens'] = predicted_codes_lens + 2 # Accounts for BOS and EOS + batch_repeated['audio_codes_lens'] = predicted_codes_lens + 2 # Accounts for BOS and EOS if 'audio' in batch_repeated: del batch_repeated['audio'] if 'audio_lens' in batch_repeated: @@ -731,32 +849,50 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): total_loss = None total_kl = None for codebook_idx in range(self.num_audio_codebooks): - policy_codebook_loss_mask = policy_model_outputs['loss_mask'][:,codebook_idx,:] - reference_codebook_loss_mask = reference_model_output['loss_mask'][:,codebook_idx,:] if not self.reference_free else None + policy_codebook_loss_mask = policy_model_outputs['loss_mask'][:, codebook_idx, :] + reference_codebook_loss_mask = ( + reference_model_output['loss_mask'][:, codebook_idx, :] if not self.reference_free else None + ) si = codebook_idx * self.num_all_tokens_per_codebook ei = si + self.num_all_tokens_per_codebook - - codebook_logits = policy_model_outputs[logits_key][:, :, si:ei] # B, T, C - codebook_labels = batch_repeated['audio_codes'][:,codebook_idx,1:] - - per_token_codebook_log_probs = self._get_per_token_logps(codebook_logits, codebook_labels, policy_codebook_loss_mask) - per_token_loss = -(torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach()) * advantages.unsqueeze(1)) - group_validities = generated_codes_and_metrics['group_validities'] # B * n_generations_per_item - per_token_loss = per_token_loss * group_validities.unsqueeze(1) # B, T + + codebook_logits = policy_model_outputs[logits_key][:, :, si:ei] # B, T, C + codebook_labels = batch_repeated['audio_codes'][:, codebook_idx, 1:] + + per_token_codebook_log_probs = self._get_per_token_logps( + codebook_logits, codebook_labels, policy_codebook_loss_mask + ) + per_token_loss = -( + torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach()) + * advantages.unsqueeze(1) + ) + group_validities = generated_codes_and_metrics['group_validities'] # B * n_generations_per_item + per_token_loss = per_token_loss * group_validities.unsqueeze(1) # B, T if not self.reference_free: with torch.no_grad(): ref_codebook_logits = reference_model_output[logits_key][:, :, si:ei] - per_token_ref_codebook_log_probs = self._get_per_token_logps(ref_codebook_logits, codebook_labels, reference_codebook_loss_mask) + per_token_ref_codebook_log_probs = self._get_per_token_logps( + ref_codebook_logits, codebook_labels, reference_codebook_loss_mask + ) # https://github.com/huggingface/trl/blob/ffcb9f4aee725a2bd072d0387afe68a4b1c7967c/trl/trainer/grpo_trainer.py#L703 - per_token_codebook_kl = torch.exp(per_token_ref_codebook_log_probs - per_token_codebook_log_probs) - (per_token_ref_codebook_log_probs - 
per_token_codebook_log_probs) - 1 + per_token_codebook_kl = ( + torch.exp(per_token_ref_codebook_log_probs - per_token_codebook_log_probs) + - (per_token_ref_codebook_log_probs - per_token_codebook_log_probs) + - 1 + ) per_token_loss = per_token_loss + self.cfg.grpo_beta * per_token_codebook_kl - codebook_kl_loss_mean = ((per_token_codebook_kl * policy_codebook_loss_mask).sum(dim=1) / policy_codebook_loss_mask.sum(dim=1)).mean() + codebook_kl_loss_mean = ( + (per_token_codebook_kl * policy_codebook_loss_mask).sum(dim=1) + / policy_codebook_loss_mask.sum(dim=1) + ).mean() else: codebook_kl_loss_mean = torch.tensor(0.0, device=self.device) if self.loss_type == "grpo": - codebook_loss = ((per_token_loss * policy_codebook_loss_mask).sum(dim=1) / policy_codebook_loss_mask.sum(dim=1)).mean() + codebook_loss = ( + (per_token_loss * policy_codebook_loss_mask).sum(dim=1) / policy_codebook_loss_mask.sum(dim=1) + ).mean() elif self.loss_type == "dr_grpo": # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py total_tokens = per_token_loss.shape[0] * self.max_decoder_steps @@ -772,7 +908,7 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): total_kl += codebook_kl_loss_mean total_loss /= self.num_audio_codebooks - + return { 'mean_reward': generated_codes_and_metrics['mean_reward'], 'std_reward': generated_codes_and_metrics['std_reward'], @@ -798,13 +934,15 @@ def validation_step(self, batch, batch_idx): val_loss = po_outputs['loss'] val_kl_loss = po_outputs['kl_loss'] - self.validation_step_outputs.append({ - 'mean_reward': mean_reward, - 'std_reward': po_outputs['std_reward'], - 'val_loss': val_loss, - 'val_kl_loss': val_kl_loss, - 'batch_metrics': batch_metrics, - }) + self.validation_step_outputs.append( + { + 'mean_reward': mean_reward, + 'std_reward': po_outputs['std_reward'], + 'val_loss': val_loss, + 'val_kl_loss': val_kl_loss, + 'batch_metrics': batch_metrics, + } + ) def on_validation_epoch_end(self): def collect(key): @@ -844,7 +982,6 @@ def collect(key): self.validation_step_outputs.clear() - # Utility functions def process_text_for_cer(input_text): """ @@ -873,31 +1010,34 @@ def process_text_for_cer(input_text): return single_space_text + def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, device): - audio_batch = [] - audio_lengths = [] - for filepath in filepaths: - audio, sr = sf.read(filepath) - if sr != 16000: - audio = librosa.core.resample(audio, orig_sr=sr, target_sr=16000) - audio_tensor = torch.tensor(audio, dtype=torch.float32, device=device) - audio_batch.append(audio_tensor) - audio_lengths.append(audio_tensor.size(0)) - - batch_audio_lens = torch.tensor(audio_lengths, device=device).long() - max_audio_len = int(batch_audio_lens.max().item()) - audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len]) - - _, speaker_embeddings = speaker_verification_model.forward( - input_signal=audio_batch, - input_signal_length=batch_audio_lens - ) + audio_batch = [] + audio_lengths = [] + for filepath in filepaths: + audio, sr = sf.read(filepath) + if sr != 16000: + audio = librosa.core.resample(audio, orig_sr=sr, target_sr=16000) + audio_tensor = torch.tensor(audio, dtype=torch.float32, device=device) + audio_batch.append(audio_tensor) + audio_lengths.append(audio_tensor.size(0)) + + batch_audio_lens = torch.tensor(audio_lengths, device=device).long() + max_audio_len = int(batch_audio_lens.max().item()) + audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len]) + + _, speaker_embeddings = 
speaker_verification_model.forward( + input_signal=audio_batch, input_signal_length=batch_audio_lens + ) + + return speaker_embeddings - return speaker_embeddings def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper_model, device): speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000) - forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None + forced_decoder_ids = ( + whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None + ) inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features inputs = inputs.to(device) with torch.no_grad(): diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index b9133ac2ea62..ad8d27f0178d 100755 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -55,6 +55,8 @@ from contextlib import contextmanager + + @contextmanager def default_precision(dtype=torch.float32): default_dtype = torch.get_default_dtype() @@ -64,6 +66,7 @@ def default_precision(dtype=torch.float32): finally: torch.set_default_dtype(default_dtype) + def get_padding(kernel_size: int, dilation: int = 1) -> int: return (kernel_size * dilation - dilation) // 2 @@ -683,7 +686,6 @@ def __init__( else: self.activation = torch.nn.Identity() - @property def input_types(self): return { @@ -1423,8 +1425,8 @@ def codebook_size(self): """Returns the size of the codebook for each group.""" return self.fsqs[0].codebook_size - #@property - #def codebook_size(self): + # @property + # def codebook_size(self): # """Returns the size of the implicit codebook.""" # return self.codebook_size_per_group**self.num_groups @@ -1507,8 +1509,7 @@ def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor }, ) def codes_to_indices(self, codes: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: - """Converts a code vector to indices. 
- """ + """Converts a code vector to indices.""" codes_rearrange = rearrange(codes, 'B D T -> D B T') codes_grouped = codes_rearrange.chunk(self.num_groups, dim=0) indices = [] @@ -2371,7 +2372,7 @@ def forward(self, audio, audio_len): win_length=self.win_length, window=self.window, return_complex=True, - center=False + center=False, ) fft_mag = torch.abs(fft) fft_mag_log = torch.log(fft_mag + self.log_guard) @@ -2594,7 +2595,7 @@ def __init__( out_channels=filters, kernel_size=down_sample_kernel_size, stride=self.down_sample_rate, - activation=activation + activation=activation, ) n_fft, hop_length, win_length = resolution @@ -2624,7 +2625,7 @@ def input_types(self): def output_types(self): return { "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()) + "out_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() @@ -2670,7 +2671,7 @@ def __init__( out_channels=filters, kernel_size=down_sample_kernel_size, stride=self.down_sample_rate, - activation=activation + activation=activation, ) self.res_block = ResidualBlockV2( channels=filters, filters=filters, kernel_size=kernel_size, activation=activation @@ -2682,16 +2683,13 @@ def remove_weight_norm(self): @property def input_types(self): - return { - "inputs": NeuralType(('B', 'C', 'T'), VoidType()), - "input_len": NeuralType(tuple('B'), LengthsType()) - } + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} @property def output_types(self): return { "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()) + "out_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() @@ -2737,16 +2735,10 @@ def __init__( input_dim = n_fft // 2 + 1 self.pre_spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) self.pre_conv = Conv1dNorm( - in_channels=input_dim, - out_channels=input_filters, - kernel_size=kernel_size, - activation=activation + in_channels=input_dim, out_channels=input_filters, kernel_size=kernel_size, activation=activation ) self.pre_res_block = ResidualBlockV2( - channels=input_filters, - filters=input_filters, - kernel_size=kernel_size, - activation=activation + channels=input_filters, filters=input_filters, kernel_size=kernel_size, activation=activation ) input_dim = input_filters self.stft_blocks = nn.ModuleList([]) @@ -2766,22 +2758,18 @@ def __init__( down_sample_rate_list = len(down_sample_filter_list) * [2] self.down_sample_blocks = nn.ModuleList([]) - for (filters, down_sample_rate) in zip(down_sample_filter_list, down_sample_rate_list): + for filters, down_sample_rate in zip(down_sample_filter_list, down_sample_rate_list): down_sample_block = DownSampleResidualBlock( channels=input_dim, filters=filters, down_sample_rate=down_sample_rate, kernel_size=kernel_size, - activation=activation + activation=activation, ) self.down_sample_blocks.append(down_sample_block) input_dim = filters - self.post_conv = Conv1dNorm( - in_channels=input_dim, - out_channels=out_dim, - kernel_size=kernel_size - ) + self.post_conv = Conv1dNorm(in_channels=input_dim, out_channels=out_dim, kernel_size=kernel_size) def remove_weight_norm(self): self.encoder.remove_weight_norm() @@ -2807,9 +2795,7 @@ def forward(self, audio, audio_len): encoded = self.pre_res_block(inputs=encoded, input_len=encoded_len) for stft_block in self.stft_blocks: - encoded, encoded_len = stft_block( - inputs=encoded, input_len=encoded_len, audio=audio, audio_len=audio_len 
- ) + encoded, encoded_len = stft_block(inputs=encoded, input_len=encoded_len, audio=audio, audio_len=audio_len) for down_sample_block in self.down_sample_blocks: encoded, encoded_len = down_sample_block(inputs=encoded, input_len=encoded_len) @@ -2873,5 +2859,7 @@ def convert_original_to_new(self, audio_tokens, audio_lens): def convert_new_to_original(self, audio_tokens, audio_lens): audio_tokens_rearrange = rearrange(audio_tokens, 'B C T -> C B T') audio_codes = self.vector_quantizer_new.decode(indices=audio_tokens_rearrange, input_len=audio_lens) - audio_tokens_original = self.vector_quantizer_original.codes_to_indices(codes=audio_codes, input_len=audio_lens) + audio_tokens_original = self.vector_quantizer_original.codes_to_indices( + codes=audio_codes, input_len=audio_lens + ) return audio_tokens_original diff --git a/nemo/collections/tts/modules/fcd_metric.py b/nemo/collections/tts/modules/fcd_metric.py index c0ca4c406500..2148d9669da1 100644 --- a/nemo/collections/tts/modules/fcd_metric.py +++ b/nemo/collections/tts/modules/fcd_metric.py @@ -23,12 +23,13 @@ be useful to explore). """ +import numpy as np import torch -from torch import nn, Tensor +from torch import Tensor, nn from torchmetrics import Metric -import numpy as np -from nemo.collections.tts.models import AudioCodecModel + from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging @@ -37,6 +38,7 @@ class CodecEmbedder(nn.Module): Embeds audio codec codes into the codec's continuous embedding space. Accepts as input either a batch of codes or a path to an audio file. """ + def __init__(self, codec: AudioCodecModel): super().__init__() self.codec = codec @@ -53,12 +55,11 @@ def encode_from_file(self, audio_path: str) -> Tensor: """ Encodes an audio file into audio codec codes. 
""" - audio_segment = AudioSegment.from_file( - audio_path, target_sr=self.codec.sample_rate) + audio_segment = AudioSegment.from_file(audio_path, target_sr=self.codec.sample_rate) assert np.issubdtype(audio_segment.samples.dtype, np.floating) audio_min = audio_segment.samples.min() audio_max = audio_segment.samples.max() - eps = 0.01 # certain ways of normalizing audio can result in samples that are slightly outside of [-1, 1] + eps = 0.01 # certain ways of normalizing audio can result in samples that are slightly outside of [-1, 1] if audio_min < (-1.0 - eps) or audio_max > (1.0 + eps): logging.warning(f"Audio samples are not normalized: min={audio_min}, max={audio_max}") samples = torch.tensor(audio_segment.samples, device=self.codec.device).unsqueeze(0) @@ -173,9 +174,11 @@ def update(self, codes: Tensor, codes_len: Tensor, is_real: bool): return if codes.shape[1] != self.model.codec.num_codebooks: - logging.warning(f"\nFCD metric received a batch of codes of shape {codes.shape}, but the model has {self.model.codec.num_codebooks} codebooks - skipping update\n") + logging.warning( + f"\nFCD metric received a batch of codes of shape {codes.shape}, but the model has {self.model.codec.num_codebooks} codebooks - skipping update\n" + ) return - + # Dequantize the codes to a continuous representation embeddings = self.model.codes_to_embedding( codes, codes_len diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index d5c0bcaac7cc..a731a5ea97ed 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -13,13 +13,16 @@ # limitations under the License. from __future__ import annotations + from enum import Enum -from nemo.utils.enum import PrettyStrEnum + import torch +from torch import Tensor + from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths -from torch import Tensor from nemo.core.classes.module import NeuralModule +from nemo.utils.enum import PrettyStrEnum class LocalTransformerType(PrettyStrEnum): @@ -67,12 +70,13 @@ def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = True) * Set to `False` when internally generating tokens in MagpieTTS sampling * Set to `True` when checking validity of tokens to be returned to user or given to the codec for decoding - """ + """ all_special_tokens = list(SpecialAudioToken) if not forbid_audio_eos: all_special_tokens.remove(SpecialAudioToken.AUDIO_EOS) return [SpecialAudioToken.get_index(token, base_codebook_size) for token in all_special_tokens] + def cosine_schedule(x: torch.Tensor): """ Maps input values from [0, 1] to [1, 0] using the first quadrant of the cosine function. @@ -80,6 +84,7 @@ def cosine_schedule(x: torch.Tensor): """ return torch.cos(x * (torch.pi / 2)) + def build_vocabs(subword_vocab: dict, subword_padding_idx: int, special_vocab: dict = None) -> tuple[dict, dict]: """ Builds the character vocabulary and the mapping from subword ids to character ids. @@ -95,7 +100,7 @@ def build_vocabs(subword_vocab: dict, subword_padding_idx: int, special_vocab: d char_vocab: A dictionary mapping character ids to their corresponding characters. 
""" org_char_vocab = {subword: subword_id for subword, subword_id in subword_vocab.items() if len(subword) == 1} - + # Add special tokens directly to char vocab if special_vocab is not None: for special_token, special_token_id in special_vocab.items(): @@ -109,27 +114,29 @@ def build_vocabs(subword_vocab: dict, subword_padding_idx: int, special_vocab: d subword_id_to_char_ids = { subword_id: tuple(char_vocab[char] for char in subword) for subword, subword_id in subword_vocab.items() } - + # Creating mapping from subword ids of special tokens to their char ids if special_vocab is not None: for special_token, special_token_id in special_vocab.items(): if special_token in subword_id_to_char_ids: raise ValueError(f"Special token {special_token} already exists in the subword id Vocabulary.") subword_id_to_char_ids[special_token_id] = (char_vocab[special_token],) - + assert max(subword_id_to_char_ids) == len(subword_id_to_char_ids) - 1 - + # Always add padding token to the end of the vocab (this is the convention used in the original code) subword_id_to_char_ids[subword_padding_idx] = (len(char_vocab),) - + return subword_id_to_char_ids, char_vocab + class CharAwareSubwordEncoder(NeuralModule): """ Char-aware subword encoder for the MagpieTTS model. This module takes subword ids as input, maps them to character ids, and then applies a transformer encoder to the character embeddings. The output is a tensor of shape (batch_size, max_subword_length, d_embed). """ + def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: int, special_vocab: dict = None): """ Args: @@ -142,8 +149,10 @@ def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: eg. special_vocab = {'': 30001, '': 30002} """ super().__init__() - self.subword_id_to_char_ids, self.char_vocab = build_vocabs(llm_tokenizer_vocab, subword_padding_idx, special_vocab) - self.embed_tokens = torch.nn.Embedding(self.vocab_size+1, d_embed, padding_idx=self.vocab_size) + self.subword_id_to_char_ids, self.char_vocab = build_vocabs( + llm_tokenizer_vocab, subword_padding_idx, special_vocab + ) + self.embed_tokens = torch.nn.Embedding(self.vocab_size + 1, d_embed, padding_idx=self.vocab_size) self.encoder = transformer_2501.Transformer( n_layers=1, d_model=d_embed, @@ -151,7 +160,7 @@ def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: sa_n_heads=8, kernel_size=1, max_length_causal_mask=256, - use_learnable_pos_emb=True + use_learnable_pos_emb=True, ) @property @@ -195,14 +204,11 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te char_mask = get_mask_from_lengths(char_lengths) char_emb = self.embed_tokens(char_ids) # char emb has the shape [B*T, N, channels], where N is the max number of chars tokens decoded from bpe tokens - x = self.encoder( - x=char_emb, - x_mask=char_mask - )['output'] + x = self.encoder(x=char_emb, x_mask=char_mask)['output'] # Get average embedding over the chars mean_emb = ((x / char_mask.unsqueeze(-1).sum(1, keepdim=True)) * char_mask.unsqueeze(-1)).sum(1) subword_emb = torch.zeros((subword_mask.size(0), subword_mask.size(1), mean_emb.size(-1)), device=device) subword_emb[subword_mask.unsqueeze(-1).expand(-1, -1, mean_emb.size(-1))] = mean_emb.view(-1) - - return subword_emb \ No newline at end of file + + return subword_emb diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index af22c0d1380c..e730ebfe12f1 100644 --- 
a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -239,8 +239,12 @@ def attn_naive( if self.make_prior_window_strict: # Make sure attention scores are lowest (eps) where prior is zero. min_score = torch.log(torch.tensor(eps)).to(attn_score_log.device) - attn_score_log = attn_score_log.masked_fill(attn_prior == 0, min_score) # Wherever prior is zero, set scores to eps. - attn_score_log = torch.clamp(attn_score_log, min=min_score) # Make sure scores are not less than eps. + attn_score_log = attn_score_log.masked_fill( + attn_prior == 0, min_score + ) # Wherever prior is zero, set scores to eps. + attn_score_log = torch.clamp( + attn_score_log, min=min_score + ) # Make sure scores are not less than eps. attn_prob = F.softmax(attn_score_log, dim=-1) else: attn_prob = F.softmax(attn_score, dim=-1) diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index 6599eb1dd911..5700f66123d3 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -134,9 +134,7 @@ def binarize_attention_parallel(attn, in_lens, out_lens): def get_mask_from_lengths( - lengths: Optional[torch.Tensor] = None, - x: Optional[torch.Tensor] = None, - pad_to_factor: Optional[int] = None + lengths: Optional[torch.Tensor] = None, x: Optional[torch.Tensor] = None, pad_to_factor: Optional[int] = None ) -> torch.Tensor: """Constructs binary mask from a 1D torch tensor of input lengths diff --git a/scripts/magpietts/codec_extraction.py b/scripts/magpietts/codec_extraction.py index 0266bc5e7039..3ecb42bc1504 100644 --- a/scripts/magpietts/codec_extraction.py +++ b/scripts/magpietts/codec_extraction.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import argparse import json -import torch -from torch.utils.data import Dataset, DataLoader +import os + import lightning.pytorch as pl +import torch from lightning.pytorch import Trainer from lightning.pytorch.strategies import DDPStrategy -from nemo.collections.tts.models import AudioCodecModel -import os -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -import argparse from lightning.pytorch.utilities import rank_zero_only +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.tts.models import AudioCodecModel + class AudioDataset(Dataset): def __init__(self, file_lists, base_audio_dirs, dataset_names, out_dir, sample_rate=22050, pad_multiple=1024): @@ -36,18 +39,19 @@ def __init__(self, file_lists, base_audio_dirs, dataset_names, out_dir, sample_r dataset_name = dataset_names[fidx] for file_path in file_list: audio_file_path = os.path.join(base_audio_dir, file_path) - self.combined_file_list.append({ - "file_path": file_path, - "audio_file_path": audio_file_path, - "dataset_name": dataset_name - }) + self.combined_file_list.append( + {"file_path": file_path, "audio_file_path": audio_file_path, "dataset_name": dataset_name} + ) def __len__(self): return len(self.combined_file_list) def get_wav_from_filepath(self, file_path): features = AudioSegment.segment_from_file( - file_path, target_sr=self.sample_rate, n_segments=-1, trim=False, + file_path, + target_sr=self.sample_rate, + n_segments=-1, + trim=False, ) audio_samples = features.samples audio = torch.tensor(audio_samples) @@ -66,7 +70,7 @@ def __getitem__(self, idx): "audio": audio, "audio_length": audio_length, "file_path": file_path, - "codec_file_path": os.path.join(self.out_dir, dataset_name, codec_file_path_rel) + "codec_file_path": os.path.join(self.out_dir, dataset_name, codec_file_path_rel), } def collate_fn(self, batch): @@ -76,9 +80,7 @@ def collate_fn(self, batch): codec_file_paths = [] max_audio_length = max(item["audio_length"].item() for item in batch) for item in batch: - audio = torch.nn.functional.pad( - item["audio"], (0, max_audio_length - item["audio"].size(0)), value=0 - ) + audio = torch.nn.functional.pad(item["audio"], (0, max_audio_length - item["audio"].size(0)), value=0) audios_padded.append(audio) audio_lengths.append(item["audio_length"]) file_paths.append(item["file_path"]) @@ -88,7 +90,7 @@ def collate_fn(self, batch): "audios": torch.stack(audios_padded), "audio_lengths": torch.stack(audio_lengths), "audio_file_paths": file_paths, - "codec_file_paths": codec_file_paths + "codec_file_paths": codec_file_paths, } @@ -108,10 +110,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): codes, codes_lengths = self(batch) for i, file_path in enumerate(batch["codec_file_paths"]): # get directory from file path - item_codes = codes[i, :, :codes_lengths[i]] # 8, T + item_codes = codes[i, :, : codes_lengths[i]] # 8, T torch.save(item_codes.cpu().type(torch.int16), file_path) return None + def read_manifest(manifest_path): records = [] with open(manifest_path, 'r') as f: @@ -122,6 +125,7 @@ def read_manifest(manifest_path): records.append(record) return records + def write_manifest(manifest_path, records): with open(manifest_path, 'w') as f: file_str = "" @@ -131,6 +135,7 @@ def write_manifest(manifest_path, records): f.write(file_str) print("Wrote {} records to: {}".format(len(records), manifest_path)) + @rank_zero_only def update_manifests(manifests, save_dir, dataset_names, 
codec_model_name): for midx, manifest in enumerate(manifests): @@ -143,14 +148,19 @@ def update_manifests(manifests, save_dir, dataset_names, codec_model_name): assert os.path.exists(audio_codes_path), "Audio codes not found: {}".format(audio_codes_path) if "context_audio_filepath" in record: - context_audio_codes_path = record["context_audio_filepath"].replace(".wav", ".pt").replace(".flac", ".pt") + context_audio_codes_path = ( + record["context_audio_filepath"].replace(".wav", ".pt").replace(".flac", ".pt") + ) context_audio_codes_path = os.path.join(save_dir, dataset_names[midx], context_audio_codes_path) record["context_audio_codes_path"] = context_audio_codes_path if ridx % 10 == 0: - assert os.path.exists(context_audio_codes_path), "Context audio codes not found: {}".format(context_audio_codes_path) + assert os.path.exists(context_audio_codes_path), "Context audio codes not found: {}".format( + context_audio_codes_path + ) write_manifest(manifest.replace(".json", "_withAudioCodes_{}.json".format(codec_model_name)), records) + def prepare_directories(base_save_dir, codec_model_name, manifests, audio_base_dirs, dataset_names): print("In prepare_directories") save_dir = os.path.join(base_save_dir, codec_model_name) @@ -172,6 +182,7 @@ def prepare_directories(base_save_dir, codec_model_name, manifests, audio_base_d print("Created directories for saving audio codes at: ", save_dir, len(file_lists)) return save_dir, file_lists + if __name__ == "__main__": """ Usage: @@ -220,11 +231,7 @@ def prepare_directories(base_save_dir, codec_model_name, manifests, audio_base_d rank = torch.distributed.get_rank() if rank == 0: save_dir, file_lists = prepare_directories( - args.save_dir, - args.codec_model_name, - manifests, - audio_base_dirs, - dataset_names + args.save_dir, args.codec_model_name, manifests, audio_base_dirs, dataset_names ) results = [save_dir, file_lists] else: @@ -233,11 +240,7 @@ def prepare_directories(base_save_dir, codec_model_name, manifests, audio_base_d save_dir, file_lists = results else: save_dir, file_lists = prepare_directories( - args.save_dir, - args.codec_model_name, - manifests, - audio_base_dirs, - dataset_names + args.save_dir, args.codec_model_name, manifests, audio_base_dirs, dataset_names ) codec_extractor = CodecExtractor(args.codec_model_path) diff --git a/scripts/magpietts/dpo/create_preference_pairs.py b/scripts/magpietts/dpo/create_preference_pairs.py index 7519100c2641..fc9deed2d69c 100644 --- a/scripts/magpietts/dpo/create_preference_pairs.py +++ b/scripts/magpietts/dpo/create_preference_pairs.py @@ -43,7 +43,12 @@ def main(): ) parser.add_argument("--group_size", type=int, default=4) parser.add_argument("--cer_threshold", type=float, default=0.02) - parser.add_argument("--min_length_threshold", type=float, default=1.5, help="Minimum length permitted. Set this shorter to allow very short sentences (which can be useful for DPO tuning.") + parser.add_argument( + "--min_length_threshold", + type=float, + default=1.5, + help="Minimum length permitted. 
Set this shorter to allow very short sentences (which can be useful for DPO tuning.", + ) parser.add_argument("--val_size", type=int, default=64) args = parser.parse_args() @@ -83,7 +88,9 @@ def main(): all_best_records, all_worst_records = create_chosen_rejected_records(records, group_size, num_chosen_per_group) print("Len all_best_records: ", len(all_best_records)) print("Len all_worst_records: ", len(all_worst_records)) - best_records, worst_records = filter_best_and_worst_records(all_best_records, all_worst_records, args.cer_threshold, args.min_length_threshold) + best_records, worst_records = filter_best_and_worst_records( + all_best_records, all_worst_records, args.cer_threshold, args.min_length_threshold + ) print("Len filtered best_records: ", len(best_records)) print("Len filtered worst_records: ", len(worst_records)) worst_records = normalize_rejected_rewards(worst_records) @@ -167,7 +174,7 @@ def pareto_rank(items): # A: (cerA, ssimA), B: (cerB, ssimB) def is_dominated(A, B): assert len(A) == 2 - assert len(B) == 2 + assert len(B) == 2 return (B[0] <= A[0]) and (B[1] >= A[1]) and (B != A) # Equivalently, check at least one strict inequality: # (B[0] < A[0]) or (B[1] > A[1]) @@ -303,6 +310,7 @@ def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_gr print(f"Skipped {num_skipped} records due to invalid entries.") return best_records, worst_records + def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.02, min_length_threshold=1.5): ridx = 0 filtered_best_records = [] @@ -315,7 +323,9 @@ def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.0 best_record = best_records[ridx] if best_record['cer_gts'] < cer_threshold: worst_record = worst_records[ridx] - if (worst_record['duration'] > 19.0 or best_record['duration'] > 19.0) or (worst_record['duration'] < min_length_threshold or best_record['duration'] < min_length_threshold): + if (worst_record['duration'] > 19.0 or best_record['duration'] > 19.0) or ( + worst_record['duration'] < min_length_threshold or best_record['duration'] < min_length_threshold + ): skipped_records += 1 ridx += 1 continue diff --git a/scripts/magpietts/dpo/create_text_contextpairs.py b/scripts/magpietts/dpo/create_text_contextpairs.py index 2e7e1ee1184e..029d44235f38 100644 --- a/scripts/magpietts/dpo/create_text_contextpairs.py +++ b/scripts/magpietts/dpo/create_text_contextpairs.py @@ -149,6 +149,7 @@ def create_audio_context_record(text, audio_context, record_type): return record + def create_text_context_record(text, text_context, record_type): """ Creates a record for a text-context pair with text context. 
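The `is_dominated` helper reformatted in the create_preference_pairs.py hunk above encodes the Pareto criterion used when ranking candidates by (CER, SSIM); a minimal usage sketch with hypothetical numbers:

def is_dominated(A, B):
    # A and B are (cer, ssim) pairs: lower CER is better, higher SSIM is better.
    # B dominates A if it is at least as good on both axes and the pairs differ.
    assert len(A) == 2 and len(B) == 2
    return (B[0] <= A[0]) and (B[1] >= A[1]) and (B != A)

# A candidate with higher CER and lower speaker similarity is dominated:
assert is_dominated((0.05, 0.80), (0.03, 0.85))
# A trade-off (better CER but worse SSIM) does not dominate:
assert not is_dominated((0.05, 0.80), (0.03, 0.70))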
@@ -164,11 +165,11 @@ def create_text_context_record(text, text_context, record_type): if text_context.endswith("\n"): text_context = text_context[:-1] record = { - 'text' : text, - 'duration' : 6.0, # Does not matter, avoids filtering out in DPO, + 'text': text, + 'duration': 6.0, # Does not matter, avoids filtering out in DPO, 'audio_filepath': text_context.split(",")[1], - 'context_text' : text_context.split(",")[0], - 'record_type' : record_type # challenging or regular + 'context_text': text_context.split(",")[0], + 'record_type': record_type, # challenging or regular } if text_context.split(",")[-1].endswith(".pt"): record['target_audio_codes_path'] = text_context.split(",")[-1] diff --git a/scripts/magpietts/eval_squimmos.py b/scripts/magpietts/eval_squimmos.py index 047ee0b743b0..06a57152314e 100644 --- a/scripts/magpietts/eval_squimmos.py +++ b/scripts/magpietts/eval_squimmos.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE -import os -import json -import torch import argparse +import json +import os + import librosa -import scipy.stats as stats import numpy as np +import scipy.stats as stats +import torch + def find_sample_audios(audio_dir): file_list = [] @@ -29,6 +31,7 @@ def find_sample_audios(audio_dir): file_list = [t[1] for t in file_list] return file_list + def compute_mean_and_confidence_interval(measurements, confidence=0.95): mean = np.mean(measurements) std_err = stats.sem(measurements) @@ -37,10 +40,15 @@ def compute_mean_and_confidence_interval(measurements, confidence=0.95): return "{:.4f} +/- {:.4f}".format(mean, confidence_interval), mean, confidence_interval + def main(): parser = argparse.ArgumentParser(description='Evaluate Squim MOS') parser.add_argument('--exp_base_dir', type=str, default="/datap/misc/ContinuousEvalResults/NewTransformerKoelTTS") - parser.add_argument('--audio_dirs', type=str, default="svencoder_small_sp_ks3_onlyphoneme_epoch242_Temp0.6_Topk80_Cfg_False_1.0_libri_val") + parser.add_argument( + '--audio_dirs', + type=str, + default="svencoder_small_sp_ks3_onlyphoneme_epoch242_Temp0.6_Topk80_Cfg_False_1.0_libri_val", + ) args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 4add09cafc78..4e2f58575010 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -13,44 +13,44 @@ # limitations under the License. 
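The `compute_mean_and_confidence_interval` helper touched in the eval_squimmos.py hunk above reports mean +/- a t-based half-width; a self-contained sketch of that calculation, with hypothetical measurement values, looks like this:

import numpy as np
import scipy.stats as stats

def mean_with_confidence_interval(measurements, confidence=0.95):
    # Sample mean and standard error of the mean.
    mean = np.mean(measurements)
    std_err = stats.sem(measurements)
    # Two-sided Student-t interval with n - 1 degrees of freedom: mean +/- t * SEM.
    half_width = std_err * stats.t.ppf((1 + confidence) / 2, len(measurements) - 1)
    return mean, half_width

mean, half_width = mean_with_confidence_interval([0.021, 0.019, 0.024, 0.020])
print(f"{mean:.4f} +/- {half_width:.4f}")  # roughly 0.0210 +/- 0.0034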
dataset_meta_info = { 'riva_hard_digits': { - 'manifest_path' : '/Data/evaluation_manifests/hard-digits-path-corrected.ndjson', - 'audio_dir' : '/Data/RIVA-TTS', - 'feature_dir' : '/Data/RIVA-TTS', + 'manifest_path': '/Data/evaluation_manifests/hard-digits-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', }, 'riva_hard_letters': { - 'manifest_path' : '/Data/evaluation_manifests/hard-letters-path-corrected.ndjson', - 'audio_dir' : '/Data/RIVA-TTS', - 'feature_dir' : '/Data/RIVA-TTS', + 'manifest_path': '/Data/evaluation_manifests/hard-letters-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', }, 'riva_hard_money': { - 'manifest_path' : '/Data/evaluation_manifests/hard-money-path-corrected.ndjson', - 'audio_dir' : '/Data/RIVA-TTS', - 'feature_dir' : '/Data/RIVA-TTS', + 'manifest_path': '/Data/evaluation_manifests/hard-money-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', }, 'riva_hard_short': { - 'manifest_path' : '/Data/evaluation_manifests/hard-short-path-corrected.ndjson', - 'audio_dir' : '/Data/RIVA-TTS', - 'feature_dir' : '/Data/RIVA-TTS', + 'manifest_path': '/Data/evaluation_manifests/hard-short-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', }, 'vctk': { - 'manifest_path' : '/Data/evaluation_manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths_silence_trimmed.json', - 'audio_dir' : '/Data/VCTK-Corpus-0.92', - 'feature_dir' : '/Data/VCTK-Corpus-0.92', + 'manifest_path': '/Data/evaluation_manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths_silence_trimmed.json', + 'audio_dir': '/Data/VCTK-Corpus-0.92', + 'feature_dir': '/Data/VCTK-Corpus-0.92', }, 'libritts_seen': { - 'manifest_path' : '/Data/evaluation_manifests/LibriTTS_seen_evalset_from_testclean_v2.json', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', + 'manifest_path': '/Data/evaluation_manifests/LibriTTS_seen_evalset_from_testclean_v2.json', + 'audio_dir': '/Data/LibriTTS', + 'feature_dir': '/Data/LibriTTS', }, 'libritts_test_clean': { - 'manifest_path' : '/Data/evaluation_manifests/LibriTTS_test_clean_withContextAudioPaths.jsonl', - 'audio_dir' : '/Data/LibriTTS', - 'feature_dir' : '/Data/LibriTTS', + 'manifest_path': '/Data/evaluation_manifests/LibriTTS_test_clean_withContextAudioPaths.jsonl', + 'audio_dir': '/Data/LibriTTS', + 'feature_dir': '/Data/LibriTTS', }, # We need an4_val_ci just for CI tests 'an4_val_ci': { - 'manifest_path' : '/home/TestData/an4_dataset/an4_val_context_v1.json', - 'audio_dir' : '/', - 'feature_dir' : None, + 'manifest_path': '/home/TestData/an4_dataset/an4_val_context_v1.json', + 'audio_dir': '/', + 'feature_dir': None, }, -} \ No newline at end of file +} diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 9e2bf08e5544..d846705ff8e5 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -13,26 +13,26 @@ # limitations under the License. 
import argparse import json +import logging import os import pprint import string -import logging +import tempfile from contextlib import contextmanager from functools import partial -import soundfile as sf -import tempfile +import librosa import numpy as np +import scripts.magpietts.evalset_config as evalset_config +import soundfile as sf import torch +from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector, WhisperForConditionalGeneration, WhisperProcessor import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate_detail -from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance from nemo.collections.tts.models import AudioCodecModel -from transformers import WhisperProcessor, WhisperForConditionalGeneration -import librosa -import scripts.magpietts.evalset_config as evalset_config -from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector +from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance + def find_generated_files(audio_dir, prefix, extension): file_list = [] @@ -44,12 +44,15 @@ def find_generated_files(audio_dir, prefix, extension): file_list = [t[1] for t in file_list] return file_list + def find_generated_audio_files(audio_dir): return find_generated_files(audio_dir=audio_dir, prefix="predicted_audio", extension=".wav") + def find_generated_codec_files(audio_dir): return find_generated_files(audio_dir=audio_dir, prefix="predicted_codes", extension=".pt") + def read_manifest(manifest_path): records = [] with open(manifest_path, 'r') as f: @@ -59,6 +62,7 @@ def read_manifest(manifest_path): records.append(json.loads(line)) return records + def process_text(input_text): # Convert text to lowercase lower_case_text = input_text.lower() @@ -76,10 +80,13 @@ def process_text(input_text): return single_space_text + def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, language, device): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) # Set the language task (optional, improves performance for specific languages) - forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None + forced_decoder_ids = ( + whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None + ) inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features inputs = inputs.to(device) # Generate transcription @@ -91,6 +98,7 @@ def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, langua result = transcription[0] return result + def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_seconds: float) -> np.ndarray: """ Pad audio to make it at least `min_seconds` long by adding silence at the end if needed. 
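The `pad_audio_to_min_length` helper whose docstring appears above right-pads short clips with zeros; a minimal standalone version, assuming the target sample count is simply `min_seconds * sampling_rate`, would be:

import numpy as np

def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_seconds: float) -> np.ndarray:
    # Samples required to reach the minimum duration.
    min_samples = int(min_seconds * sampling_rate)
    padding_needed = min_samples - len(audio_np)
    if padding_needed > 0:
        # Append silence so very short signals do not break the embedding extractor.
        audio_np = np.pad(audio_np, (0, padding_needed), mode='constant', constant_values=0)
    return audio_np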
@@ -107,6 +115,7 @@ def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_second audio_np = np.pad(audio_np, (0, padding_needed), mode='constant', constant_values=0) return audio_np + @contextmanager def nemo_log_level(level): """ @@ -125,6 +134,7 @@ def nemo_log_level(level): # restore the original level when the context manager is exited (even if an exception was raised) logger.setLevel(original_level) + def extract_embedding(model, extractor, audio_path, device, sv_model_type): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) # pad to 0.5 seconds as the extractor may not be able to handle very short signals @@ -134,7 +144,7 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type): with torch.inference_mode(): embeddings = model(inputs).embeddings else: # Titanet - with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: # the embedding model doesn't accept NumPy arrays, so we write to a temporary file sf.write(temp_file.name, speech_array, samplerate=16000) with torch.inference_mode(): @@ -142,8 +152,16 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type): return embeddings.squeeze() -def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", - codecmodel_path=None): + +def evaluate( + manifest_path, + audio_dir, + generated_audio_dir, + language="en", + sv_model_type="titanet", + asr_model_name="stt_en_conformer_transducer_large", + codecmodel_path=None, +): audio_file_lists = find_generated_audio_files(generated_audio_dir) records = read_manifest(manifest_path) assert len(audio_file_lists) == len(records) @@ -171,13 +189,17 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo speaker_verification_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv').to(device).eval() else: feature_extractor = None - speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_large' + ) speaker_verification_model = speaker_verification_model.to(device) speaker_verification_model.eval() with nemo_log_level(logging.ERROR): # The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily print("Loading `titanet_small` model...") - speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_small') + speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_small' + ) speaker_verification_model_alternate = speaker_verification_model_alternate.to(device) speaker_verification_model_alternate.eval() @@ -207,7 +229,7 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo context_audio_filepath = os.path.join(audio_dir, context_audio_filepath) # Update the FCD metric for *real* codes if fcd_metric is not None: - fcd_metric.update_from_audio_file(gt_audio_filepath, True) + fcd_metric.update_from_audio_file(gt_audio_filepath, True) pred_audio_filepath = audio_file_lists[ridx] if fcd_metric is not None: @@ -221,9 +243,13 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0].text gt_audio_text = 
process_text(gt_audio_text) else: - pred_text = transcribe_with_whisper(whisper_model, whisper_processor, pred_audio_filepath, language, device) + pred_text = transcribe_with_whisper( + whisper_model, whisper_processor, pred_audio_filepath, language, device + ) pred_text = process_text(pred_text) - gt_audio_text = transcribe_with_whisper(whisper_model, whisper_processor, gt_audio_filepath, language, device) + gt_audio_text = transcribe_with_whisper( + whisper_model, whisper_processor, gt_audio_filepath, language, device + ) gt_audio_text = process_text(gt_audio_text) except Exception as e: print("Error during ASR: {}".format(e)) @@ -251,60 +277,95 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo # update FCD metric if fcd_metric is not None: - predicted_codes = torch.load(pred_codes_filepath).unsqueeze(0) # B, C, T + predicted_codes = torch.load(pred_codes_filepath).unsqueeze(0) # B, C, T predicted_codes_lens = torch.tensor([predicted_codes.size(-1)], dtype=torch.int, device=device) fcd_metric.update(predicted_codes, predicted_codes_lens, False) pred_context_ssim = 0.0 gt_context_ssim = 0.0 with torch.inference_mode(): - extract_embedding_fn = partial(extract_embedding, model=speaker_verification_model, extractor=feature_extractor, device=device, sv_model_type=sv_model_type) - extract_embedding_fn_alternate = partial(extract_embedding, model=speaker_verification_model_alternate, extractor=feature_extractor, device=device, sv_model_type=sv_model_type) + extract_embedding_fn = partial( + extract_embedding, + model=speaker_verification_model, + extractor=feature_extractor, + device=device, + sv_model_type=sv_model_type, + ) + extract_embedding_fn_alternate = partial( + extract_embedding, + model=speaker_verification_model_alternate, + extractor=feature_extractor, + device=device, + sv_model_type=sv_model_type, + ) # Ground truth vs. predicted gt_speaker_embedding = extract_embedding_fn(audio_path=gt_audio_filepath) pred_speaker_embedding = extract_embedding_fn(audio_path=pred_audio_filepath) - pred_gt_ssim = torch.nn.functional.cosine_similarity(gt_speaker_embedding, pred_speaker_embedding, dim=0).item() + pred_gt_ssim = torch.nn.functional.cosine_similarity( + gt_speaker_embedding, pred_speaker_embedding, dim=0 + ).item() # Ground truth vs. predicted (alternate model) gt_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=gt_audio_filepath) pred_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=pred_audio_filepath) - pred_gt_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0).item() + pred_gt_ssim_alternate = torch.nn.functional.cosine_similarity( + gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0 + ).item() if context_audio_filepath is not None: context_speaker_embedding = extract_embedding_fn(audio_path=context_audio_filepath) context_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=context_audio_filepath) - + # Predicted vs. context - pred_context_ssim = torch.nn.functional.cosine_similarity(pred_speaker_embedding, context_speaker_embedding, dim=0).item() + pred_context_ssim = torch.nn.functional.cosine_similarity( + pred_speaker_embedding, context_speaker_embedding, dim=0 + ).item() # Ground truth vs. 
context - gt_context_ssim = torch.nn.functional.cosine_similarity(gt_speaker_embedding, context_speaker_embedding, dim=0).item() + gt_context_ssim = torch.nn.functional.cosine_similarity( + gt_speaker_embedding, context_speaker_embedding, dim=0 + ).item() # Predicted vs. context (alternate model) - pred_context_ssim_alternate = torch.nn.functional.cosine_similarity(pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() + pred_context_ssim_alternate = torch.nn.functional.cosine_similarity( + pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0 + ).item() # Ground truth vs. context (alternate model) - gt_context_ssim_alternate = torch.nn.functional.cosine_similarity(gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0).item() - - filewise_metrics.append({ - 'gt_text': gt_text, - 'pred_text': pred_text, - 'gt_audio_text': gt_audio_text, - 'detailed_cer': detailed_cer, - 'detailed_wer': detailed_wer, - 'cer': detailed_cer[0], - 'wer': detailed_wer[0], - 'pred_gt_ssim': pred_gt_ssim, - 'pred_context_ssim': pred_context_ssim, - 'gt_context_ssim': gt_context_ssim, - 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, - 'pred_context_ssim_alternate': pred_context_ssim_alternate, - 'gt_context_ssim_alternate': gt_context_ssim_alternate, - 'gt_audio_filepath': gt_audio_filepath, - 'pred_audio_filepath': pred_audio_filepath, - 'context_audio_filepath': context_audio_filepath - }) - - filewise_metrics_keys_to_save = ['cer', 'wer', 'pred_context_ssim', 'pred_text', 'gt_text', 'gt_audio_filepath', 'pred_audio_filepath', 'context_audio_filepath'] + gt_context_ssim_alternate = torch.nn.functional.cosine_similarity( + gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0 + ).item() + + filewise_metrics.append( + { + 'gt_text': gt_text, + 'pred_text': pred_text, + 'gt_audio_text': gt_audio_text, + 'detailed_cer': detailed_cer, + 'detailed_wer': detailed_wer, + 'cer': detailed_cer[0], + 'wer': detailed_wer[0], + 'pred_gt_ssim': pred_gt_ssim, + 'pred_context_ssim': pred_context_ssim, + 'gt_context_ssim': gt_context_ssim, + 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, + 'pred_context_ssim_alternate': pred_context_ssim_alternate, + 'gt_context_ssim_alternate': gt_context_ssim_alternate, + 'gt_audio_filepath': gt_audio_filepath, + 'pred_audio_filepath': pred_audio_filepath, + 'context_audio_filepath': context_audio_filepath, + } + ) + + filewise_metrics_keys_to_save = [ + 'cer', + 'wer', + 'pred_context_ssim', + 'pred_text', + 'gt_text', + 'gt_audio_filepath', + 'pred_audio_filepath', + 'context_audio_filepath', + ] filtered_filewise_metrics = [] for m in filewise_metrics: filtered_filewise_metrics.append({k: m[k] for k in filewise_metrics_keys_to_save}) @@ -323,21 +384,36 @@ def evaluate(manifest_path, audio_dir, generated_audio_dir, language="en", sv_mo avg_metrics['cer_filewise_avg'] = sum([m['detailed_cer'][0] for m in filewise_metrics]) / len(filewise_metrics) avg_metrics['wer_filewise_avg'] = sum([m['detailed_wer'][0] for m in filewise_metrics]) / len(filewise_metrics) avg_metrics['cer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=True)[0] - avg_metrics['wer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[0] + avg_metrics['wer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[ + 0 + ] avg_metrics['ssim_pred_gt_avg'] = sum([m['pred_gt_ssim'] for m in filewise_metrics]) / 
len(filewise_metrics) - avg_metrics['ssim_pred_context_avg'] = sum([m['pred_context_ssim'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_pred_context_avg'] = sum([m['pred_context_ssim'] for m in filewise_metrics]) / len( + filewise_metrics + ) avg_metrics['ssim_gt_context_avg'] = sum([m['gt_context_ssim'] for m in filewise_metrics]) / len(filewise_metrics) - avg_metrics['ssim_pred_gt_avg_alternate'] = sum([m['pred_gt_ssim_alternate'] for m in filewise_metrics]) / len(filewise_metrics) - avg_metrics['ssim_pred_context_avg_alternate'] = sum([m['pred_context_ssim_alternate'] for m in filewise_metrics]) / len(filewise_metrics) - avg_metrics['ssim_gt_context_avg_alternate'] = sum([m['gt_context_ssim_alternate'] for m in filewise_metrics]) / len(filewise_metrics) - avg_metrics["cer_gt_audio_cumulative"] = word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=True)[0] - avg_metrics["wer_gt_audio_cumulative"] = word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=False)[0] + avg_metrics['ssim_pred_gt_avg_alternate'] = sum([m['pred_gt_ssim_alternate'] for m in filewise_metrics]) / len( + filewise_metrics + ) + avg_metrics['ssim_pred_context_avg_alternate'] = sum( + [m['pred_context_ssim_alternate'] for m in filewise_metrics] + ) / len(filewise_metrics) + avg_metrics['ssim_gt_context_avg_alternate'] = sum( + [m['gt_context_ssim_alternate'] for m in filewise_metrics] + ) / len(filewise_metrics) + avg_metrics["cer_gt_audio_cumulative"] = word_error_rate_detail( + hypotheses=gt_audio_texts, references=gt_texts, use_cer=True + )[0] + avg_metrics["wer_gt_audio_cumulative"] = word_error_rate_detail( + hypotheses=gt_audio_texts, references=gt_texts, use_cer=False + )[0] avg_metrics["frechet_codec_distance"] = fcd pprint.pprint(avg_metrics) return avg_metrics, filewise_metrics + def main(): # audio_dir="/datap/misc/Datasets/riva" \ parser = argparse.ArgumentParser(description='Evaluate Generated Audio') @@ -354,8 +430,14 @@ def main(): args.manifest_path = dataset_meta_info[args.evalset]['manifest_path'] args.audio_dir = dataset_meta_info[args.evalset]['audio_dir'] - evaluate(args.manifest_path, args.audio_dir, args.generated_audio_dir, args.whisper_language, sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b") - + evaluate( + args.manifest_path, + args.audio_dir, + args.generated_audio_dir, + args.whisper_language, + sv_model_type="wavlm", + asr_model_name="nvidia/parakeet-ctc-0.6b", + ) if __name__ == "__main__": diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index f4de0f7eea8e..ef19897b1e4e 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -18,28 +18,31 @@ import os import shutil import time -from typing import List -from pathlib import Path from functools import partial +from pathlib import Path +from typing import List -import scripts.magpietts.evalset_config as evalset_config -import scripts.magpietts.evaluate_generated_audio as evaluate_generated_audio +import matplotlib.pyplot as plt import numpy as np +import pandas as pd import scipy.stats as stats +import scripts.magpietts.evalset_config as evalset_config +import scripts.magpietts.evaluate_generated_audio as evaluate_generated_audio import soundfile as sf import torch from omegaconf.omegaconf import OmegaConf, open_dict from PIL import Image -import matplotlib.pyplot as plt -import pandas as pd from nemo.collections.asr.parts.utils.manifest_utils import 
read_manifest +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset from nemo.collections.tts.models import MagpieTTSModel -from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer # EVALUATION_DATASETS is the full list of datasets for evaluation of a new model. -EVALUATION_DATASETS = "riva_hard_digits,riva_hard_letters,riva_hard_money,riva_hard_short,vctk,libritts_seen,libritts_test_clean" +EVALUATION_DATASETS = ( + "riva_hard_digits,riva_hard_letters,riva_hard_money,riva_hard_short,vctk,libritts_seen,libritts_test_clean" +) + def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): metrics = {} @@ -53,8 +56,9 @@ def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0 metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) return metrics + def update_config(model_cfg, codecmodel_path, legacy_codebooks=False, legacy_text_conditioning=False): - ''' helper function to rename older yamls from t5 to magpie ''' + '''helper function to rename older yamls from t5 to magpie''' model_cfg.codecmodel_path = codecmodel_path if hasattr(model_cfg, 'text_tokenizer'): # Backward compatibility for models trained with absolute paths in text_tokenizer @@ -84,7 +88,9 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False, legacy_tex if legacy_codebooks: # Added to address backward compatibility arising from # https://github.com/blisc/NeMo/pull/64 - print("WARNING: Using legacy codebook indices for backward compatibility. Should only be used with old checkpoints.") + print( + "WARNING: Using legacy codebook indices for backward compatibility. Should only be used with old checkpoints." 
+ ) num_audio_tokens_per_codebook = model_cfg.num_audio_tokens_per_codebook model_cfg.forced_num_all_tokens_per_codebook = num_audio_tokens_per_codebook model_cfg.forced_audio_eos_id = num_audio_tokens_per_codebook - 1 @@ -104,6 +110,7 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False, legacy_tex sample_rate = None return model_cfg, sample_rate + def update_ckpt(state_dict): new_state_dict = {} for key in state_dict.keys(): @@ -126,6 +133,7 @@ def delete_old_generated_files(output_dir): for f in glob.glob(f"{output_dir}/predicted_audio*.wav"): os.remove(f) + def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: str): # Create dataframe from list of dicts df = pd.DataFrame(metrics) @@ -139,9 +147,7 @@ def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: assert column in df # Create empty lists to store the parts objects for each DataFrame # Plot the violin plots for each DataFrame - axs[i].violinplot( - df[column], showmedians=True, positions=[i], widths=0.5 - ) + axs[i].violinplot(df[column], showmedians=True, positions=[i], widths=0.5) axs[i].set_title(column) axs[i].set_xticks([i]) @@ -151,14 +157,7 @@ def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: # Calculate and display the mean value for each DataFrame mean = df[column].mean() sem = df[column].sem() - axs[i].plot( - i, - mean, - "o", - color="red", - markersize=4, - label="Mean (95%CI)" - ) + axs[i].plot(i, mean, "o", color="red", markersize=4, label="Mean (95%CI)") label_numeric = f"{mean:.2f}±{1.96 * sem:.2f}" axs[i].text(i + 0.06, mean, label_numeric, ha="center", va="top") @@ -174,7 +173,7 @@ def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: def create_combined_violin_plots(dataset_metrics: dict, metric_keys: List[str], output_png: str): """ Create box plots comparing multiple datasets for each metric in a single figure. 
- + Args: dataset_metrics: Dictionary where keys are dataset names and values are lists of metric dictionaries metric_keys: List of metric names to plot @@ -184,25 +183,25 @@ def create_combined_violin_plots(dataset_metrics: dict, metric_keys: List[str], datasets = list(dataset_metrics.keys()) num_datasets = len(datasets) num_metrics = len(metric_keys) - + # Create figure with subplots for each metric fig, axs = plt.subplots(1, num_metrics, figsize=(num_metrics * 6, 6)) - + # Handle case where there's only one metric (axs won't be an array) if num_metrics == 1: axs = [axs] - + # Define colors for different datasets colors = plt.cm.Set3(np.linspace(0, 1, num_datasets)) - + for metric_idx, metric in enumerate(metric_keys): ax = axs[metric_idx] - + # Collect data for all datasets for this metric all_data = [] positions = [] dataset_labels = [] - + for dataset_idx, dataset in enumerate(datasets): df = pd.DataFrame(dataset_metrics[dataset]) if metric in df.columns: @@ -210,26 +209,32 @@ def create_combined_violin_plots(dataset_metrics: dict, metric_keys: List[str], all_data.append(data) positions.append(dataset_idx + 1) dataset_labels.append(dataset) - + # Create box plots if all_data: - bp = ax.boxplot(all_data, positions=positions, widths=0.6, patch_artist=True, - showmeans=True, meanline=False, meanprops={'marker': 'o', 'markerfacecolor': 'red', - 'markeredgecolor': 'red', 'markersize': 6}) - + bp = ax.boxplot( + all_data, + positions=positions, + widths=0.6, + patch_artist=True, + showmeans=True, + meanline=False, + meanprops={'marker': 'o', 'markerfacecolor': 'red', 'markeredgecolor': 'red', 'markersize': 6}, + ) + # Color the box plots for i, patch in enumerate(bp['boxes']): patch.set_facecolor(colors[i]) patch.set_alpha(0.7) - + # Add mean labels for each dataset for i, (data, pos) in enumerate(zip(all_data, positions)): mean = data.mean() sem = data.sem() - + label_numeric = f"{mean:.3f}±{1.96 * sem:.3f}" ax.text(pos + 0.1, mean, label_numeric, ha="left", va="center", fontsize=8) - + # Set labels and title ax.set_title(f"{metric.upper()}", fontsize=12, fontweight='bold') ax.set_xticks(positions) @@ -237,14 +242,14 @@ def create_combined_violin_plots(dataset_metrics: dict, metric_keys: List[str], ax.grid(True, linestyle="dotted", alpha=0.7) ax.set_xlabel("Dataset") ax.set_ylabel(metric) - + # Set y-axis limit for CER metrics if 'cer' in metric.lower(): ax.set_ylim(0, 0.3) - + # Add overall title fig.suptitle("Performance Comparison Across Datasets", fontsize=14, fontweight='bold') - + # Adjust layout and save plt.tight_layout() plt.savefig(output_png, format="png", bbox_inches="tight", dpi=300) @@ -253,40 +258,40 @@ def create_combined_violin_plots(dataset_metrics: dict, metric_keys: List[str], def run_inference( - hparams_file, - checkpoint_file, - nemo_file, - datasets, - out_dir, - temperature, - topk, - codecmodel_path, - use_cfg, - cfg_scale, - batch_size, - sv_model, - asr_model_name, - num_repeats=1, - apply_attention_prior=False, - attention_prior_epsilon=1e-3, - attention_prior_lookahead_window=10, - estimate_alignment_from_layers=None, - apply_prior_to_layers=None, - start_prior_after_n_audio_steps=10, - confidence_level=0.95, - use_local_transformer=False, - maskgit_n_steps=3, - maskgit_noise_scale=0.0, - maskgit_fixed_schedule=None, - maskgit_sampling_type=None, - legacy_codebooks=False, - legacy_text_conditioning=False, - clean_up_disk=False, - hparams_file_from_wandb=False, - log_exp_name=False, - compute_fcd=False, - violin_plot_metrics=['cer', 'pred_context_ssim'] - ): + 
hparams_file, + checkpoint_file, + nemo_file, + datasets, + out_dir, + temperature, + topk, + codecmodel_path, + use_cfg, + cfg_scale, + batch_size, + sv_model, + asr_model_name, + num_repeats=1, + apply_attention_prior=False, + attention_prior_epsilon=1e-3, + attention_prior_lookahead_window=10, + estimate_alignment_from_layers=None, + apply_prior_to_layers=None, + start_prior_after_n_audio_steps=10, + confidence_level=0.95, + use_local_transformer=False, + maskgit_n_steps=3, + maskgit_noise_scale=0.0, + maskgit_fixed_schedule=None, + maskgit_sampling_type=None, + legacy_codebooks=False, + legacy_text_conditioning=False, + clean_up_disk=False, + hparams_file_from_wandb=False, + log_exp_name=False, + compute_fcd=False, + violin_plot_metrics=['cer', 'pred_context_ssim'], +): # Load model if hparams_file is not None and checkpoint_file is not None: model_cfg = OmegaConf.load(hparams_file) @@ -297,7 +302,9 @@ def run_inference( model_cfg = model_cfg.value with open_dict(model_cfg): - model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning) + model_cfg, cfg_sample_rate = update_config( + model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning + ) model = MagpieTTSModel(cfg=model_cfg) model.use_kv_cache_for_inference = True @@ -311,7 +318,9 @@ def run_inference( elif nemo_file is not None: model_cfg = MagpieTTSModel.restore_from(nemo_file, return_config=True) with open_dict(model_cfg): - model_cfg, cfg_sample_rate = update_config(model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning) + model_cfg, cfg_sample_rate = update_config( + model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning + ) model = MagpieTTSModel.restore_from(nemo_file, override_config_path=model_cfg) model.use_kv_cache_for_inference = True checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] @@ -343,7 +352,7 @@ def run_inference( f"{attention_prior_epsilon}_{attention_prior_lookahead_window}_{start_prior_after_n_audio_steps}_" f"{''.join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else 'None'}_" f"{''.join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else 'None'}_" - ) + ) checkpoint_name += ( f"LT_{use_local_transformer}_" f"MaskGit_{maskgit_n_steps}_{maskgit_sampling_type}_{''.join([str(l) for l in maskgit_fixed_schedule]) if maskgit_fixed_schedule is not None else 'None'}_" @@ -383,7 +392,7 @@ def run_inference( context_duration_max = model.cfg.get('context_duration_max', 5.0) if context_duration_min < 5.0 and context_duration_max > 5.0: context_duration_min = 5.0 - context_duration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. + context_duration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. dataset_filewise_metrics_all_repeats = [] # Store metrics for all repeats of this dataset for repeat_idx in range(num_repeats): @@ -415,7 +424,9 @@ def run_inference( context_duration_min=context_duration_min, context_duration_max=context_duration_max, ) - assert len(test_dataset) == len(manifest_records), f"Dataset length and manifest length should be the same. Dataset length: {len(test_dataset)}, Manifest length: {len(manifest_records)}" + assert len(test_dataset) == len( + manifest_records + ), f"Dataset length and manifest length should be the same. 
Dataset length: {len(test_dataset)}, Manifest length: {len(manifest_records)}" test_dataset.text_tokenizer = model.tokenizer # Set phoneme prob = 1 for g2p @@ -426,7 +437,6 @@ def run_inference( g2p = model.tokenizer.g2p if g2p is not None: g2p.phoneme_probability = 1.0 - test_data_loader = torch.utils.data.DataLoader( test_dataset, @@ -449,7 +459,15 @@ def run_inference( batch_cuda[key] = batch[key] st = time.time() - predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics, cross_attention_maps, _ = model.infer_batch( + ( + predicted_audio, + predicted_audio_lens, + predicted_codes, + predicted_codes_lens, + rtf_metrics, + cross_attention_maps, + _, + ) = model.infer_batch( batch_cuda, max_decoder_steps=440, temperature=temperature, @@ -467,7 +485,7 @@ def run_inference( maskgit_n_steps=maskgit_n_steps, maskgit_noise_scale=maskgit_noise_scale, maskgit_fixed_schedule=maskgit_fixed_schedule, - maskgit_sampling_type=maskgit_sampling_type + maskgit_sampling_type=maskgit_sampling_type, ) all_rtf_metrics.append(rtf_metrics) @@ -478,11 +496,11 @@ def run_inference( cross_attn_map_image.save(os.path.join(audio_dir, f"cross_attn_map_{item_idx}.png")) predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() - predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]] + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] audio_path = os.path.join(pred_audio_dir, f"predicted_audio_{item_idx}.wav") sf.write(audio_path, predicted_audio_np, model.sample_rate) codes_path = os.path.join(pred_audio_dir, f"predicted_codes_{item_idx}.pt") - predicted_codes_current = predicted_codes[idx, :, :predicted_codes_lens[idx]] # C, T' + predicted_codes_current = predicted_codes[idx, :, : predicted_codes_lens[idx]] # C, T' torch.save(predicted_codes_current, codes_path) codec_file_paths.append(codes_path) context_audio_path = manifest_records[item_idx].get('context_audio_filepath', None) @@ -508,11 +526,13 @@ def run_inference( language=language, sv_model_type=sv_model, asr_model_name=asr_model_name, - codecmodel_path=codecmodel_path if compute_fcd else None + codecmodel_path=codecmodel_path if compute_fcd else None, ) metrics_n_repeated.append(metrics) - dataset_filewise_metrics_all_repeats.extend(filewise_metrics) # Collect all filewise metrics for combined plot - + dataset_filewise_metrics_all_repeats.extend( + filewise_metrics + ) # Collect all filewise metrics for combined plot + with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: json.dump(metrics, f, indent=4) @@ -541,14 +561,25 @@ def run_inference( # Store filewise metrics for this dataset for combined plotting all_datasets_filewise_metrics[dataset] = dataset_filewise_metrics_all_repeats - metric_keys = ['cer_filewise_avg', 'wer_filewise_avg', 'cer_cumulative', 'wer_cumulative', - 'ssim_pred_gt_avg', 'ssim_pred_context_avg', 'ssim_gt_context_avg', - 'ssim_pred_gt_avg_alternate', 'ssim_pred_context_avg_alternate', 'ssim_gt_context_avg_alternate', - 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative' - ] + metric_keys = [ + 'cer_filewise_avg', + 'wer_filewise_avg', + 'cer_cumulative', + 'wer_cumulative', + 'ssim_pred_gt_avg', + 'ssim_pred_context_avg', + 'ssim_gt_context_avg', + 'ssim_pred_gt_avg_alternate', + 'ssim_pred_context_avg_alternate', + 'ssim_gt_context_avg_alternate', + 'cer_gt_audio_cumulative', + 'wer_gt_audio_cumulative', + ] if compute_fcd: metric_keys.append('frechet_codec_distance') - metrics_mean_ci = 
compute_mean_and_confidence_interval(metrics_n_repeated, metric_keys, confidence=confidence_level) + metrics_mean_ci = compute_mean_and_confidence_interval( + metrics_n_repeated, metric_keys, confidence=confidence_level + ) all_experiment_csv_with_ci = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") if not os.path.exists(all_experiment_csv_with_ci): with open(all_experiment_csv_with_ci, "w") as f: @@ -565,7 +596,6 @@ def run_inference( f.write(data) print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") - measurements = [m['ssim_pred_context_avg'] for m in metrics_n_repeated] ssim_current = np.mean(measurements) ssim_per_dataset.append(ssim_current) @@ -577,7 +607,7 @@ def run_inference( if len(all_datasets_filewise_metrics) > 1: # Only create combined plot if we have multiple datasets combined_output_png = os.path.join(out_dir, f"{checkpoint_name}_combined_violin_plot.png") create_combined_violin_plots(all_datasets_filewise_metrics, violin_plot_metrics, combined_output_png) - + # Average across datasets ssim = np.mean(ssim_per_dataset) cer = np.mean(cer_per_dataset) @@ -598,11 +628,17 @@ def main(): parser.add_argument('--out_dir', type=str, default="/datap/misc/Evals/LocalTransformerAblations2") parser.add_argument('--temperature', type=float, default=0.6) parser.add_argument('--use_cfg', action='store_true') - parser.add_argument('--use_local_transformer', action='store_true', help="Enables use of local transformer for inference; applies to both Autoregressive and MaskGit sampling.") + parser.add_argument( + '--use_local_transformer', + action='store_true', + help="Enables use of local transformer for inference; applies to both Autoregressive and MaskGit sampling.", + ) parser.add_argument('--maskgit_n_steps', type=int, default=3) parser.add_argument('--maskgit_noise_scale', type=float, default=0.0) parser.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) - parser.add_argument('--maskgit_sampling_type', default=None, choices=["default", "causal", "purity_causal", "purity_default"]) + parser.add_argument( + '--maskgit_sampling_type', default=None, choices=["default", "causal", "purity_causal", "purity_default"] + ) parser.add_argument('--cfg_scale', type=float, default=2.5) parser.add_argument('--apply_attention_prior', action='store_true') parser.add_argument('--attention_prior_epsilon', type=float, default=0.1) @@ -613,8 +649,10 @@ def main(): parser.add_argument('--topk', type=int, default=80) parser.add_argument('--batch_size', type=int, default=32) # Parameters for evaluation - parser.add_argument('--sv_model', type=str, default="titanet") # titanet, wavlm - parser.add_argument('--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b") # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b + parser.add_argument('--sv_model', type=str, default="titanet") # titanet, wavlm + parser.add_argument( + '--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b" + ) # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b parser.add_argument('--num_repeats', type=int, default=1) parser.add_argument('--confidence_level', type=float, default=0.95) parser.add_argument('--legacy_codebooks', action='store_true') @@ -622,9 +660,19 @@ def main(): parser.add_argument('--clean_up_disk', action='store_true') parser.add_argument('--cer_target', type=float, default=None) parser.add_argument('--ssim_target', type=float, default=None) - parser.add_argument('--log_exp_name', action='store_true', 
help="Include the experiment name (derived from the checkpoint path) in the output folder name.") + parser.add_argument( + '--log_exp_name', + action='store_true', + help="Include the experiment name (derived from the checkpoint path) in the output folder name.", + ) parser.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation") - parser.add_argument('--violin_plot_metrics', type=str, nargs='*', default=['cer','pred_context_ssim'], help="Which metrics to add the violin plot.") + parser.add_argument( + '--violin_plot_metrics', + type=str, + nargs='*', + default=['cer', 'pred_context_ssim'], + help="Which metrics to add the violin plot.", + ) args = parser.parse_args() if args.datasets is None: @@ -671,16 +719,23 @@ def main(): hparams_file_from_wandb=args.hparams_file_from_wandb, log_exp_name=args.log_exp_name, compute_fcd=compute_fcd, - violin_plot_metrics=args.violin_plot_metrics + violin_plot_metrics=args.violin_plot_metrics, ) # Mode 1: Run inference from provided hparams and checkpoint files - if (args.hparams_files is not None) and (args.checkpoint_files is not None) and (args.hparams_files != "null") and (args.checkpoint_files != "null"): + if ( + (args.hparams_files is not None) + and (args.checkpoint_files is not None) + and (args.hparams_files != "null") + and (args.checkpoint_files != "null") + ): hparam_files = args.hparams_files.split(",") checkpoint_files = args.checkpoint_files.split(",") print("Running inference for hparams files: ", hparam_files) print("Running inference for checkpoint files: ", checkpoint_files) - assert len(hparam_files) == len(checkpoint_files), "Number of hparams files and checkpoint files should be the same." + assert len(hparam_files) == len( + checkpoint_files + ), "Number of hparams files and checkpoint files should be the same." for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): cer, ssim = run_inference_w_args( hparams_file=hparams_file, @@ -710,4 +765,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/scripts/tts_dataset_to_lhotse/create_shars.py b/scripts/tts_dataset_to_lhotse/create_shars.py index f680ba80c877..aa72e7fd16c4 100644 --- a/scripts/tts_dataset_to_lhotse/create_shars.py +++ b/scripts/tts_dataset_to_lhotse/create_shars.py @@ -11,26 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from pathlib import Path +import argparse +import csv import json import os -import argparse import shutil -import csv -import soundfile as sf +from pathlib import Path + ### from nemo.collections.tts.models import AudioCodecModel import librosa -import torch import numpy as np -from tqdm import tqdm -from matplotlib import pyplot as plt - -from lhotse import CutSet, SupervisionSegment, Recording, AudioSource +import soundfile as sf +import torch +from lhotse import AudioSource, CutSet, Recording, SupervisionSegment +from lhotse.array import Array, TemporalArray +from lhotse.audio import RecordingSet from lhotse.cut.base import Cut from lhotse.features.base import Features, FeatureSet -from lhotse.array import TemporalArray, Array from lhotse.shar.writers import AudioTarWriter -from lhotse.audio import RecordingSet +from matplotlib import pyplot as plt +from tqdm import tqdm + def json_reader(filename): with open(filename) as f: @@ -38,9 +39,7 @@ def json_reader(filename): yield json.loads(line) -def create_shar_from_manifest( - manifest, audio_root_path, out_shar_dir, shard_size=1000 -): +def create_shar_from_manifest(manifest, audio_root_path, out_shar_dir, shard_size=1000): in_manifest = list(json_reader(manifest)) print(f"...loaded {manifest} # of datapoints {len(in_manifest)}") num_shard = int(len(in_manifest) / shard_size) @@ -86,7 +85,6 @@ def create_shar_from_manifest( # Language target target_language.append(language) - print("Done extracting data from manifest") print(len(user_recordings)) cuts = CutSet.from_manifests(recordings=RecordingSet.from_recordings(user_recordings)) @@ -122,22 +120,19 @@ def create_shar_from_manifest( out_shar_dir.mkdir(parents=True, exist_ok=True) shard_size = shard_size assert len(user_recordings) % shard_size != 0, "Lhotse breaks if feat_list is a multiple of shard_size" - exported = cuts.to_shar( - out_shar_dir, fields={"recording": "wav"}, num_jobs=4, shard_size=shard_size - ) + exported = cuts.to_shar(out_shar_dir, fields={"recording": "wav"}, num_jobs=4, shard_size=shard_size) print(f"...share created") print(f"...Exporting target_audio to tar files") for i, path in tqdm(enumerate(exported["cuts"])): path = path[0] out_path = path.replace("cuts", "target_audio").replace(".jsonl.gz", ".tar") - with AudioTarWriter( - out_path, shard_size=None, format="flac" - ) as writer: + with AudioTarWriter(out_path, shard_size=None, format="flac") as writer: for cut in CutSet.from_file(path): writer.write(cut.id, cut.target_audio.load_audio(), manifest=cut.target_audio, sampling_rate=22050) print(f"...Exported target_audio to tar files") + def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -174,7 +169,6 @@ def main(): shard_size=args.shard_size, ) + if __name__ == "__main__": main() - - diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index 140d44415577..e03a73b3c24c 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -344,6 +344,7 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path): assert batches[4]["audio"].shape == (4, 8000) # exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts + @pytest.mark.pleasefixme def test_dataloader_from_lhotse_cuts_pad_min_duration(cutset_path: Path): config = OmegaConf.create( @@ -1938,6 +1939,7 @@ def test_multimodal_text_audio_dataloading_randomized_round_robin_strategy( assert torch.is_tensor(ex.answer_ids) assert 
torch.is_tensor(ex.mask) + @pytest.mark.pleasefixme def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): config = OmegaConf.create( @@ -1967,6 +1969,7 @@ def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: assert isinstance(cut, MixedCut) assert -5.0 < cut.tracks[1].snr < 5.0 + @pytest.mark.pleasefixme def test_dataloader_with_noise_lhotse_jsonl(cutset_path: Path): config = OmegaConf.create( @@ -1996,6 +1999,7 @@ def test_dataloader_with_noise_lhotse_jsonl(cutset_path: Path): assert isinstance(cut, MixedCut) assert -5.0 < cut.tracks[1].snr < 5.0 + @pytest.mark.pleasefixme def test_dataloader_with_noise_nemo_tar(cutset_path: Path, nemo_tarred_manifest_path_multi: Path): noise_json, noise_tar = nemo_tarred_manifest_path_multi diff --git a/tests/collections/tts/modules/test_fcd_metric.py b/tests/collections/tts/modules/test_fcd_metric.py index dc87548ddd8f..25f9a4227934 100644 --- a/tests/collections/tts/modules/test_fcd_metric.py +++ b/tests/collections/tts/modules/test_fcd_metric.py @@ -15,8 +15,8 @@ import pytest import torch -from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance class TestFrechetCodecDistance: From 932da5edcaf0414744220bed418f784cc197fd05 Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 25 Aug 2025 13:31:54 -0400 Subject: [PATCH 079/113] Add attended text token to attention plots (#14560) * add fix for nemo logging stack level Signed-off-by: Jason * add change to log predicted text token in attention plots Signed-off-by: Jason * Apply isort and black reformatting Signed-off-by: blisc * address lint issues Signed-off-by: Jason * Apply isort and black reformatting Signed-off-by: blisc * update magpie tests Signed-off-by: Jason --------- Signed-off-by: Jason Signed-off-by: blisc Co-authored-by: blisc --- .github/workflows/cicd-main-automodel.yml | 152 ++++++++-------- .github/workflows/cicd-main-speech.yml | 186 +++++++++++--------- .github/workflows/cicd-main.yml | 2 +- nemo/collections/tts/models/magpietts.py | 33 ++-- nemo/collections/tts/parts/utils/helpers.py | 5 +- nemo/utils/nemo_logging.py | 13 +- 6 files changed, 207 insertions(+), 184 deletions(-) diff --git a/.github/workflows/cicd-main-automodel.yml b/.github/workflows/cicd-main-automodel.yml index edc0ceb26d0a..89b8addd0261 100644 --- a/.github/workflows/cicd-main-automodel.yml +++ b/.github/workflows/cicd-main-automodel.yml @@ -59,79 +59,79 @@ jobs: cpu-only: ${{ matrix.cpu-only || false }} is_optional: ${{ matrix.is-optional || false }} - e2e-tests: - strategy: - fail-fast: false - matrix: - include: - - runner: self-hosted-azure-gpus-1 - script: L2_VLM_HF_Transformer_PEFT - - runner: self-hosted-azure - script: L2_VLM_HF_Transformer_PEFT_FSDP2 - - runner: self-hosted-azure-gpus-1 - script: L2_VLM_HF_Transformer_PEFT_4bit - is-optional: true - - runner: self-hosted-azure - script: L2_VLM_HF_Transformer_SFT_FSDP2 - - runner: self-hosted-azure - script: L2_HF_Transformer_PEFT_notebook - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_PEFT - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_PEFT_nemorun - - runner: self-hosted-azure - script: L2_HF_Transformer_PEFT_2gpu - - runner: self-hosted-azure - script: L2_HF_Transformer_PEFT_2gpu_FSDP2_liger - - runner: azure-gpu-vm-runner1-h100 - script: L2_HF_Transformer_PEFT_2gpu_FSDP2_fp8 - - runner: self-hosted-azure - script: 
L2_HF_Transformer_PEFT_2gpu_FSDP2 - - runner: self-hosted-azure - script: L2_HF_Transformer_PEFT_2gpu_nemorun - - runner: self-hosted-azure - script: L2_HF_Transformer_SFT_2gpu - - runner: self-hosted-azure - script: L2_HF_Transformer_SFT_2gpu_FSDP2 - - runner: azure-gpu-vm-runner1-h100 - script: L2_HF_Transformer_SFT_2gpu_FSDP2_fp8 - - runner: self-hosted-azure - script: L2_HF_Transformer_SFT_2gpu_nemorun - - runner: self-hosted-azure - script: L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2 - - runner: self-hosted-azure - script: L2_HF_Transformer_SFT_FSDP2_2gpu - - runner: self-hosted-azure - script: L2_HF_Transformer_PT_2gpu - - runner: self-hosted-azure - script: L2_HF_Transformer_PT_2gpu_nemorun - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_PT - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_PT_nemorun - - runner: self-hosted-azure - script: L2_HF_Transformer_SFT_notebook - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_SFT - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_SFT_nemorun - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_SFT_TE_Acceleration - - runner: self-hosted-azure-gpus-1 - script: L2_HF_Transformer_PT_TE_Acceleration - needs: [unit-tests] - runs-on: ${{ matrix.runner }} - name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }} - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main - with: - runner: ${{ runner.name }} - script: ${{ matrix.script }} - tests_to_run: ${{ inputs.test_to_run }} - image: ${{ inputs.image-name }} - is_optional: ${{ matrix.is-optional || false }} + # e2e-tests: disabled for magpie dev banch + # strategy: + # fail-fast: false + # matrix: + # include: + # - runner: self-hosted-azure-gpus-1 + # script: L2_VLM_HF_Transformer_PEFT + # - runner: self-hosted-azure + # script: L2_VLM_HF_Transformer_PEFT_FSDP2 + # - runner: self-hosted-azure-gpus-1 + # script: L2_VLM_HF_Transformer_PEFT_4bit + # is-optional: true + # - runner: self-hosted-azure + # script: L2_VLM_HF_Transformer_SFT_FSDP2 + # - runner: self-hosted-azure + # script: L2_HF_Transformer_PEFT_notebook + # - runner: self-hosted-azure-gpus-1 + # script: L2_HF_Transformer_PEFT + # - runner: self-hosted-azure-gpus-1 + # script: L2_HF_Transformer_PEFT_nemorun + # - runner: self-hosted-azure + # script: L2_HF_Transformer_PEFT_2gpu + # - runner: self-hosted-azure + # script: L2_HF_Transformer_PEFT_2gpu_FSDP2_liger + # - runner: azure-gpu-vm-runner1-h100 + # script: L2_HF_Transformer_PEFT_2gpu_FSDP2_fp8 + # - runner: self-hosted-azure + # script: L2_HF_Transformer_PEFT_2gpu_FSDP2 + # - runner: self-hosted-azure + # script: L2_HF_Transformer_PEFT_2gpu_nemorun + # - runner: self-hosted-azure + # script: L2_HF_Transformer_SFT_2gpu + # - runner: self-hosted-azure + # script: L2_HF_Transformer_SFT_2gpu_FSDP2 + # - runner: azure-gpu-vm-runner1-h100 + # script: L2_HF_Transformer_SFT_2gpu_FSDP2_fp8 + # - runner: self-hosted-azure + # script: L2_HF_Transformer_SFT_2gpu_nemorun + # - runner: self-hosted-azure + # script: L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2 + # - runner: self-hosted-azure + # script: L2_HF_Transformer_SFT_FSDP2_2gpu + # - runner: self-hosted-azure + # script: L2_HF_Transformer_PT_2gpu + # - runner: self-hosted-azure + # script: L2_HF_Transformer_PT_2gpu_nemorun + # - runner: self-hosted-azure-gpus-1 + # script: L2_HF_Transformer_PT + # - runner: self-hosted-azure-gpus-1 + # script: 
L2_HF_Transformer_PT_nemorun + # - runner: self-hosted-azure + # script: L2_HF_Transformer_SFT_notebook + # - runner: self-hosted-azure-gpus-1 + # script: L2_HF_Transformer_SFT + # - runner: self-hosted-azure-gpus-1 + # script: L2_HF_Transformer_SFT_nemorun + # - runner: self-hosted-azure-gpus-1 + # script: L2_HF_Transformer_SFT_TE_Acceleration + # - runner: self-hosted-azure-gpus-1 + # script: L2_HF_Transformer_PT_TE_Acceleration + # needs: [unit-tests] + # runs-on: ${{ matrix.runner }} + # name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }} + # steps: + # - name: Checkout + # uses: actions/checkout@v4 + # with: + # path: ${{ github.run_id }} + # - name: main + # uses: NVIDIA/NeMo/.github/actions/test-template@main + # with: + # runner: ${{ runner.name }} + # script: ${{ matrix.script }} + # tests_to_run: ${{ inputs.test_to_run }} + # image: ${{ inputs.image-name }} + # is_optional: ${{ matrix.is-optional || false }} diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index a9febfc115f9..986bff428e67 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -84,102 +84,112 @@ jobs: fail-fast: false matrix: include: - - runner: self-hosted-azure-gpus-1 - script: ASR_dev_run_Speech_to_Text - - runner: self-hosted-azure-gpus-1 - script: ASR_dev_run_Speech_to_Text_WPE_CitriNet - - runner: self-hosted-azure-gpus-1 - script: ASR_dev_run_Speech_Pre-training_-_CitriNet - - runner: self-hosted-azure-gpus-1 - script: Optional_ASR_dev_run_Speech_To_Text_Finetuning - is-optional: true - - runner: self-hosted-azure-gpus-1 - script: Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning - is-optional: true - - runner: self-hosted-azure-gpus-1 - script: ASR_dev_run_Speech_to_Text_WPE_-_Conformer - - runner: self-hosted-azure-gpus-1 - script: ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer - - runner: self-hosted-azure-gpus-1 - script: L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader - - runner: self-hosted-azure-gpus-1 - script: L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader - - runner: self-hosted-azure-gpus-1 - script: L2_ASR_Adapters_Linear_Adapters - - runner: self-hosted-azure-gpus-1 - script: L2_ASR_Adapters_RelPos_MHA_Adapters + # - runner: self-hosted-azure-gpus-1. 
selectively run some L2 speech tests for magpie dev branch + # script: ASR_dev_run_Speech_to_Text + # - runner: self-hosted-azure-gpus-1 + # script: ASR_dev_run_Speech_to_Text_WPE_CitriNet + # - runner: self-hosted-azure-gpus-1 + # script: ASR_dev_run_Speech_Pre-training_-_CitriNet + # - runner: self-hosted-azure-gpus-1 + # script: Optional_ASR_dev_run_Speech_To_Text_Finetuning + # is-optional: true + # - runner: self-hosted-azure-gpus-1 + # script: Optional_ASR_dev_run_Speech_To_Text_HF_Finetuning + # is-optional: true + # - runner: self-hosted-azure-gpus-1 + # script: ASR_dev_run_Speech_to_Text_WPE_-_Conformer + # - runner: self-hosted-azure-gpus-1 + # script: ASR_dev_run-part_two_Speech_to_Text_WPE_-_Squeezeformer + # - runner: self-hosted-azure-gpus-1 + # script: L2_ASR_Multi-dataloader_dev_run_Speech_to_Text_multi-dataloader + # - runner: self-hosted-azure-gpus-1 + # script: L2_ASR_Multi-dataloader_dev_run_Speech_to_Label_multi-dataloader + # - runner: self-hosted-azure-gpus-1 + # script: L2_ASR_Adapters_Linear_Adapters + # - runner: self-hosted-azure-gpus-1 + # script: L2_ASR_Adapters_RelPos_MHA_Adapters + # - runner: self-hosted-azure + # script: L2_Speech_to_Text_EMA + # - runner: self-hosted-azure-gpus-1 + # script: L2_Speech_to_Text_AED + # - runner: self-hosted-azure-gpus-1 + # script: L2_Speaker_dev_run_Speech_to_Label + # - runner: self-hosted-azure + # script: L2_Speech_Estimate_Duration_Bins + # - runner: self-hosted-azure + # script: L2_Speech_Batch_Size_OOMptimizer + # - runner: self-hosted-azure + # script: Optional_L2_Speech_Batch_Size_OOMptimizer_Canary + # is-optional: true + # - runner: self-hosted-azure + # script: L2_Speech_Transcription_Speech_to_Text_Transcribe + # - runner: self-hosted-azure + # script: L2_Speech_Transcription_Canary_Transcribe_Full_Manifest + # - runner: self-hosted-azure + # script: L2_Speech_Transcription_Canary_Transcribe_With_Prompt + # - runner: self-hosted-azure + # script: L2_Speech_Transcription_Canary_Transcribe_Audio_Dir + # - runner: self-hosted-azure + # script: L2_Longform_Speech_Transcription_Canary_Chunked_Infer_from_Audio_Dir + # - runner: self-hosted-azure + # script: L2_Longform_Speech_Transcription_with_TimeStamps_Canary_Chunked_Infer_from_Audio_Dir + # - runner: self-hosted-azure + # script: L2_Longform_Speech_Transcription_with_TimeStamps_Canary_Chunked_Infer_from_Manifest + # - runner: self-hosted-azure-gpus-1 + # script: Speech_Checkpoints_tests + # timeout: 20 + # - runner: self-hosted-azure-gpus-1 + # script: L2_Speaker_dev_run_Speaker_Recognition + # - runner: self-hosted-azure-gpus-1 + # script: L2_Speaker_dev_run_Speaker_Diarization + # - runner: self-hosted-azure-gpus-1 + # script: L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer + # - runner: self-hosted-azure + # script: L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference + # - runner: self-hosted-azure + # script: L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference + # - runner: self-hosted-azure + # script: L2_Speaker_dev_run_Clustering_Diarizer_Inference + # - runner: self-hosted-azure + # script: L2_Speaker_dev_run_Neural_Diarizer_Inference + # - runner: self-hosted-azure + # script: L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation + # - runner: self-hosted-azure + # script: L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav + # - runner: self-hosted-azure + # script: L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3 + # - script: L2_HF_Transformer_SpeechLM_SFT_2gpu + # runner: self-hosted-azure + # - 
script: L2_SpeechLM_LoRA_TP1PP1_MBS2 + # runner: self-hosted-azure + # - runner: self-hosted-azure-gpus-1 + # script: L2_TTS_Fast_dev_runs_1_Tacotron_2 + # - runner: self-hosted-azure + # script: L2_TTS_Fast_dev_runs_1_WaveGlow + # - runner: self-hosted-azure + # script: L2_TTS_Fast_dev_runs_1_FastPitch + # - runner: self-hosted-azure + # script: L2_TTS_Fast_dev_runs_1_Hifigan + # - runner: self-hosted-azure + # script: L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference + # - runner: self-hosted-azure + # script: L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference - runner: self-hosted-azure - script: L2_Speech_to_Text_EMA - - runner: self-hosted-azure-gpus-1 - script: L2_Speech_to_Text_AED - - runner: self-hosted-azure-gpus-1 - script: L2_Speaker_dev_run_Speech_to_Label - - runner: self-hosted-azure - script: L2_Speech_Estimate_Duration_Bins - - runner: self-hosted-azure - script: L2_Speech_Batch_Size_OOMptimizer - - runner: self-hosted-azure - script: Optional_L2_Speech_Batch_Size_OOMptimizer_Canary - is-optional: true - - runner: self-hosted-azure - script: L2_Speech_Transcription_Speech_to_Text_Transcribe - - runner: self-hosted-azure - script: L2_Speech_Transcription_Canary_Transcribe_Full_Manifest - - runner: self-hosted-azure - script: L2_Speech_Transcription_Canary_Transcribe_With_Prompt - - runner: self-hosted-azure - script: L2_Speech_Transcription_Canary_Transcribe_Audio_Dir - - runner: self-hosted-azure - script: L2_Longform_Speech_Transcription_Canary_Chunked_Infer_from_Audio_Dir - - runner: self-hosted-azure - script: L2_Longform_Speech_Transcription_with_TimeStamps_Canary_Chunked_Infer_from_Audio_Dir - - runner: self-hosted-azure - script: L2_Longform_Speech_Transcription_with_TimeStamps_Canary_Chunked_Infer_from_Manifest - - runner: self-hosted-azure-gpus-1 - script: Speech_Checkpoints_tests - timeout: 20 - - runner: self-hosted-azure-gpus-1 - script: L2_Speaker_dev_run_Speaker_Recognition - - runner: self-hosted-azure-gpus-1 - script: L2_Speaker_dev_run_Speaker_Diarization - - runner: self-hosted-azure-gpus-1 - script: L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer - - runner: self-hosted-azure - script: L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference - - runner: self-hosted-azure - script: L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference - - runner: self-hosted-azure - script: L2_Speaker_dev_run_Clustering_Diarizer_Inference - - runner: self-hosted-azure - script: L2_Speaker_dev_run_Neural_Diarizer_Inference - - runner: self-hosted-azure - script: L2_Speaker_dev_run_Multispeaker_ASR_Data_Simulation - - runner: self-hosted-azure - script: L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav - - runner: self-hosted-azure - script: L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Ru_QN_with_mp3 - - script: L2_HF_Transformer_SpeechLM_SFT_2gpu - runner: self-hosted-azure - - script: L2_SpeechLM_LoRA_TP1PP1_MBS2 - runner: self-hosted-azure - - runner: self-hosted-azure-gpus-1 - script: L2_TTS_Fast_dev_runs_1_Tacotron_2 - - runner: self-hosted-azure - script: L2_TTS_Fast_dev_runs_1_WaveGlow + script: SPEECHLM_HF_Training_DuplexS2S - runner: self-hosted-azure - script: L2_TTS_Fast_dev_runs_1_FastPitch + script: SPEECHLM_HF_Training_DuplexS2SSpeechDecoder - runner: self-hosted-azure - script: L2_TTS_Fast_dev_runs_1_Hifigan + script: SPEECHLM_HF_Training_SALM - runner: self-hosted-azure - script: L2_G2P_Models_G2P_Conformer_training_evaluation_and_inference + script: 
L2_TTS_Fast_dev_runs_Magpietts_DecoderContext - runner: self-hosted-azure - script: L2_G2P_Models_HeteronymClassificationModel_training_evaluation_and_inference + script: L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder - runner: self-hosted-azure - script: SPEECHLM_HF_Training_DuplexS2S + script: L2_TTS_Fast_dev_runs_Magpietts_OnlinePO - runner: self-hosted-azure - script: SPEECHLM_HF_Training_DuplexS2SSpeechDecoder + script: L2_TTS_InferEvaluate_Magpietts_ZeroShot - runner: self-hosted-azure - script: SPEECHLM_HF_Training_SALM + script: L2_TTS_InferEvaluate_Magpietts_SeenSpeakers needs: [unit-tests] runs-on: ${{ matrix.runner }} name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }} diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 7e3e686df5de..d786790950ea 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -325,7 +325,7 @@ jobs: - cicd-import-tests - L0_Setup_Test_Data_And_Models - cicd-main-unit-tests - - cicd-main-nemo2 + # - cicd-main-nemo2. not needed for magpie dev branch - cicd-main-automodel - cicd-main-speech if: always() diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 16d28be553e6..fb9a060e6561 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -597,7 +597,7 @@ def maskgit_apply_random_mask(self, codes): # Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule. # Codes: (B, C, T) mask = self.maskgit_create_random_mask(codes) - ## replace some tokens with MASK_TOKEN + # replace some tokens with MASK_TOKEN codes_with_mask = torch.where(mask, self.mask_token_id, codes) return codes_with_mask, mask @@ -799,7 +799,7 @@ def local_transformer_sample_maskgit( actual_batch_size = topk_indices.size(0) // 2 assert ( topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size] - ).all(), f"Topk indices are not the same for conditional and unconditional codes" + ).all(), "Topk indices are not the same for conditional and unconditional codes" # replace masks of the top-k confident codebooks with the codes that were sampled for them unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices) @@ -899,7 +899,7 @@ def local_transformer_sample_maskgit( codes = sampled_codes assert not ( codes == self.mask_token_id - ).any(), f"Codes contain mask tokens after completion of MaskGit sampling" + ).any(), "Codes contain mask tokens after completion of MaskGit sampling" # break stacked groups of frames into individual frames codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute( @@ -1628,7 +1628,7 @@ def process_batch(self, batch, mode="train"): local_transformer_logits = None if self.local_transformer_type != LocalTransformerType.NO_LT: if self.local_transformer_type == LocalTransformerType.MASKGIT: - ## Maskgit ## + # Maskgit # randomly replace some positions with MASK_TOKEN audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked) # TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part of @@ -1644,7 +1644,7 @@ def process_batch(self, batch, mode="train"): frame_stacking_factor=self.frame_stacking_factor, ) else: - ## Autoregressive ## + # Autoregressive assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type" local_transformer_logits = self.compute_local_transformer_logits( 
dec_out[:, dec_context_size:, :], audio_codes_target_unstacked, targets_offset_by_one=False @@ -1807,7 +1807,7 @@ def validation_step(self, batch, batch_idx): [batch_output['aligner_attn_soft']], audio_codes_lens_target, text_lens, - prefix=f"val/aligner_encoder_attn", + prefix="val/aligner_encoder_attn", ) ) @@ -1817,7 +1817,7 @@ def validation_step(self, batch, batch_idx): [batch_output['aligner_attn_hard'].unsqueeze(1)], audio_codes_lens_target, text_lens, - prefix=f"val/aligner_encoder_attn_hard", + prefix="val/aligner_encoder_attn_hard", ) ) @@ -1884,7 +1884,7 @@ def get_most_attended_text_timestep( item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end] if item_attention_scores.size(0) == 0: # This means the sentence has ended - attended_timestep = text_lens[bidx] - 1 + attended_timestep = text_lens[bidx].item() - 1 else: attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep text_time_step_attended.append(attended_timestep) @@ -1958,7 +1958,9 @@ def get_inference_attention_plots( predicted_codes_lens, batch_size, compute_all_heads_attn_maps, + last_attended_timestep, ): + last_attended_timestep = np.array(last_attended_timestep).T cross_attention_scores_all_timesteps = torch.stack( cross_attention_scores_all_timesteps, dim=2 ) # B, text_timesteps, T' @@ -1975,7 +1977,10 @@ def get_inference_attention_plots( item_cross_attention_scores = cross_attention_scores_all_timesteps[ bidx, : text_lens[bidx], : predicted_codes_lens[bidx] ] - cross_attn_np = plot_alignment_to_numpy(item_cross_attention_scores.cpu().numpy()) + cross_attn_np = plot_alignment_to_numpy( + item_cross_attention_scores.cpu().numpy(), + attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]], + ) cross_attention_maps.append(cross_attn_np) item_all_head_cross_attn_maps = [] if compute_all_heads_attn_maps: @@ -1984,7 +1989,8 @@ def get_inference_attention_plots( bidx, : text_lens[bidx], : predicted_codes_lens[bidx] ] headwise_cross_attn_np = plot_alignment_to_numpy( - item_headwise_cross_attention_scores.cpu().numpy() + item_headwise_cross_attention_scores.cpu().numpy(), + attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]], ) item_all_head_cross_attn_maps.append(headwise_cross_attn_np) headwise_cross_attention_maps.append(item_all_head_cross_attn_maps) @@ -2300,6 +2306,7 @@ def infer_batch( predicted_codes_lens, text.size(0), compute_all_heads_attn_maps, + last_attended_timesteps, ) return ( predicted_audio, @@ -2335,7 +2342,7 @@ def test_step(self, batch, batch_idx): is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: raise ValueError( - f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." + "Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." 
) for idx in range(predicted_audio.size(0)): @@ -2345,8 +2352,8 @@ def test_step(self, batch, batch_idx): if is_wandb: log_dict = { - f"test/predicted_audio": wandb.Audio( - predicted_audio_np, sample_rate=self.sample_rate, caption=f"Predicted Audio" + "test/predicted_audio": wandb.Audio( + predicted_audio_np, sample_rate=self.sample_rate, caption="Predicted Audio" ), } logger.experiment.log(log_dict, step=item_idx) diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index 5700f66123d3..98ddb0b1ef18 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -442,12 +442,15 @@ def tacotron2_log_to_wandb_func( swriter.log({"audios": audios}) -def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None): +def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None, attended=None): if phoneme_seq: fig, ax = plt.subplots(figsize=(15, 10)) else: fig, ax = plt.subplots(figsize=(6, 4)) im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax) + if attended is not None: + for step in range(len(attended) - 1): + plt.plot([step, step + 1], [attended[step], attended[step + 1]], color='red', linewidth=1, linestyle='--') ax.set_title(title) fig.colorbar(im, ax=ax) xlabel = 'Decoder timestep' diff --git a/nemo/utils/nemo_logging.py b/nemo/utils/nemo_logging.py index bcc7ad199603..d19d5b37c30b 100644 --- a/nemo/utils/nemo_logging.py +++ b/nemo/utils/nemo_logging.py @@ -31,11 +31,14 @@ class LogMode(enum.IntEnum): + """Enum to control how many times to log messages in NeMo logging""" + EACH = 0 # Log the message each time ONCE = 1 # Log the message only once. The same message will not be logged again. class Logger(metaclass=Singleton): + """NeMo's logging class. 
Makes some changes on top of python's logging module to aid model devs.""" # Level 0 NOTSET = _logging.NOTSET @@ -378,7 +381,7 @@ def debug(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.debug("Houston, we have a %s", "thorny problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.DEBUG) and not self._logged_once(msg, mode): - self._logger._log(Logger.DEBUG, msg, args, **kwargs) + self._logger._log(Logger.DEBUG, msg, args, **kwargs, stacklevel=2) def info(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -390,7 +393,7 @@ def info(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.info("Houston, we have a %s", "interesting problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.INFO) and not self._logged_once(msg, mode): - self._logger._log(Logger.INFO, msg, args, **kwargs) + self._logger._log(Logger.INFO, msg, args, **kwargs, stacklevel=2) def warning(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -402,7 +405,7 @@ def warning(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.WARNING) and not self._logged_once(msg, mode): - self._logger._log(Logger.WARNING, msg, args, **kwargs) + self._logger._log(Logger.WARNING, msg, args, **kwargs, stacklevel=2) def error(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -414,7 +417,7 @@ def error(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.error("Houston, we have a %s", "major problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.ERROR) and not self._logged_once(msg, mode): - self._logger._log(Logger.ERROR, msg, args, **kwargs) + self._logger._log(Logger.ERROR, msg, args, **kwargs, stacklevel=2) def critical(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -430,4 +433,4 @@ def critical(self, msg, *args, mode=LogMode.EACH, **kwargs): and self._logger.isEnabledFor(Logger.CRITICAL) and not self._logged_once(msg, mode) ): - self._logger._log(Logger.CRITICAL, msg, args, **kwargs) + self._logger._log(Logger.CRITICAL, msg, args, **kwargs, stacklevel=2) From 9037131f7c9736b7b97fdee62ead827606040dde Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Wed, 27 Aug 2025 13:33:39 -0700 Subject: [PATCH 080/113] [eval][bugfix] avoid overidden cross attention maps for multiple repeats. Instead, save one map per repeat. 
(#14577) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- scripts/magpietts/infer_and_evaluate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index ef19897b1e4e..b84f8e1a0a71 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -132,6 +132,8 @@ def delete_old_generated_files(output_dir): os.remove(f) for f in glob.glob(f"{output_dir}/predicted_audio*.wav"): os.remove(f) + for f in glob.glob(f"{output_dir}/cross_attn_map_*.png"): + os.remove(f) def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: str): @@ -493,7 +495,7 @@ def run_inference( print(f"Time taken for inference: {et-st}", predicted_audio.size()) for idx in range(predicted_audio.size(0)): cross_attn_map_image = Image.fromarray(cross_attention_maps[idx]) - cross_attn_map_image.save(os.path.join(audio_dir, f"cross_attn_map_{item_idx}.png")) + cross_attn_map_image.save(os.path.join(pred_audio_dir, f"cross_attn_map_{item_idx}.png")) predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] From 6d40bd46c7474adbcb256427c4ede0dbb65e127e Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 28 Aug 2025 08:56:15 -0400 Subject: [PATCH 081/113] Fix Magpie Tests; Update default params for Decoder Context Yamls (#14587) * update head size Signed-off-by: Jason * update speechllm vars Signed-off-by: Jason * disable ASR test Signed-off-by: Jason * disable coverage on magpie branch Signed-off-by: Jason --------- Signed-off-by: Jason --- .github/workflows/cicd-main-speech.yml | 8 +- .github/workflows/cicd-main.yml | 100 +++++++++--------- .../tts/conf/magpietts/magpietts_dc_en.yaml | 1 + .../magpietts/magpietts_lhotse_dc_en.yaml | 1 + .../magpietts_lhotse_dc_en_tiny.yaml | 2 +- .../speechlm2/models/duplex_s2s_model.py | 2 +- .../models/duplex_s2s_speech_decoder_model.py | 2 +- 7 files changed, 59 insertions(+), 57 deletions(-) diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index 986bff428e67..6d274271647f 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -38,10 +38,10 @@ jobs: - script: L0_Unit_Tests_GPU_ASR runner: self-hosted-azure-gpus-1 timeout: 30 - - script: L0_Unit_Tests_CPU_ASR - runner: self-hosted-azure-cpu - cpu-only: true - timeout: 20 + # - script: L0_Unit_Tests_CPU_ASR # Disable for magpie since we have GPU tests and this test time outs + # runner: self-hosted-azure-cpu + # cpu-only: true + # timeout: 20 - script: L0_Unit_Tests_GPU_TTS runner: self-hosted-azure-gpus-1 - script: L0_Unit_Tests_CPU_TTS diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d786790950ea..ce61beb1456c 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -423,53 +423,53 @@ jobs: exit 1 fi - Coverage: - runs-on: ubuntu-latest - needs: [pre-flight, Nemo_CICD_Test] - if: | - needs.pre-flight.outputs.test_to_run != '[]' - && needs.pre-flight.outputs.components_to_run != '[]' - && ( - success() - || needs.Nemo_CICD_Test.result == 'success' - ) - && !cancelled() - strategy: - matrix: - flag: [unit-test, e2e] - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Download coverage reports of current branch - uses: 
actions/download-artifact@v4 - with: - pattern: coverage-${{ matrix.flag }}-* - - - name: Get total coverage of current branch - shell: bash -x -e -u -o pipefail {0} - if: always() - run: | - pip install coverage - - ls -al . - ls -al coverage-*/ - coverage combine --keep $(ls coverage-*/.coverage) - coverage report -i - rm -rf coverage-* - ls -al - - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - flags: ${{ matrix.flag }} - - - name: Upload artifacts - uses: actions/upload-artifact@v4 - with: - name: coverage-${{ matrix.flag }}-aggregated - path: | - .coverage - include-hidden-files: true + # Coverage: + # runs-on: ubuntu-latest + # needs: [pre-flight, Nemo_CICD_Test] + # if: | + # needs.pre-flight.outputs.test_to_run != '[]' + # && needs.pre-flight.outputs.components_to_run != '[]' + # && ( + # success() + # || needs.Nemo_CICD_Test.result == 'success' + # ) + # && !cancelled() + # strategy: + # matrix: + # flag: [unit-test, e2e] + # steps: + # - name: Checkout + # uses: actions/checkout@v4 + + # - name: Download coverage reports of current branch + # uses: actions/download-artifact@v4 + # with: + # pattern: coverage-${{ matrix.flag }}-* + + # - name: Get total coverage of current branch + # shell: bash -x -e -u -o pipefail {0} + # if: always() + # run: | + # pip install coverage + + # ls -al . + # ls -al coverage-*/ + # coverage combine --keep $(ls coverage-*/.coverage) + # coverage report -i + # rm -rf coverage-* + # ls -al + + # - name: Upload coverage reports to Codecov + # uses: codecov/codecov-action@v5 + # with: + # token: ${{ secrets.CODECOV_TOKEN }} + # verbose: true + # flags: ${{ matrix.flag }} + + # - name: Upload artifacts + # uses: actions/upload-artifact@v4 + # with: + # name: coverage-${{ matrix.flag }}-aggregated + # path: | + # .coverage + # include-hidden-files: true diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index eef597a89f5e..9703c9d3b58b 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -119,6 +119,7 @@ model: has_xattn: true xa_d_memory: 768 xa_n_heads: 1 + xa_d_head: 128 is_causal: true apply_norm_to_cond: true apply_norm_out: true diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index bac847091df4..337de93ddba3 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -134,6 +134,7 @@ model: has_xattn: true xa_d_memory: 768 xa_n_heads: 1 + xa_d_head: 128 is_causal: true apply_norm_to_cond: true apply_norm_out: true diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml index b57ba7c50a45..0d26dfdb7996 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml @@ -114,7 +114,7 @@ model: d_model: 512 d_ffn: 2048 sa_n_heads: 8 - kernel_size: 1 + kernel_size: 3 p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: false diff --git a/nemo/collections/speechlm2/models/duplex_s2s_model.py b/nemo/collections/speechlm2/models/duplex_s2s_model.py index e0d90a6200d9..2de158d88be9 100644 --- a/nemo/collections/speechlm2/models/duplex_s2s_model.py +++ b/nemo/collections/speechlm2/models/duplex_s2s_model.py @@ -55,7 +55,7 @@ def __init__(self, 
cfg: dict) -> None: self.cfg = DictConfig(cfg) setup_audio_codec(self) - self._codebook_size = self.audio_codec.vector_quantizer.codebook_size_per_group + self._codebook_size = self.audio_codec.vector_quantizer.codebook_size self._num_codebooks = self.audio_codec.vector_quantizer.num_groups # We load the pretrained HF LLM using "ForCausalLM" variant so that we can obtain the diff --git a/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py b/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py index df246c6640b5..3605e886b3e4 100644 --- a/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py +++ b/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py @@ -56,7 +56,7 @@ def __init__(self, cfg: dict) -> None: self.cfg = DictConfig(cfg) setup_audio_codec(self) - self._codebook_size = self.audio_codec.vector_quantizer.codebook_size_per_group + self._codebook_size = self.audio_codec.vector_quantizer.codebook_size self._num_codebooks = self.audio_codec.vector_quantizer.num_groups # We load the pretrained HF LLM using "ForCausalLM" variant so that we can obtain the From 86be9f231ae558ea7187467733c14c0191aad37e Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Thu, 28 Aug 2025 13:27:13 -0700 Subject: [PATCH 082/113] [lhotse][step3] added shuffling option to nemo manifests before sharding. (#14601) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- .../create_lhotse_shar_from_nemo_manifest.py | 61 ++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py b/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py index b200168b3a67..f3acb91848f8 100644 --- a/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py +++ b/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py @@ -39,6 +39,8 @@ --processing-chunk-size ${CHUNK_SIZE} \ --audio-format ${AUDIO_FORMAT} \ --log-level ${LOG_LEVEL} \ + --shuffle \ + --shuffle-seed 42 \ 2>&1 | tee ./log/create_lhotse_shar_from_nemo_manifest.stdout Expected output: @@ -63,6 +65,7 @@ import logging import math import os +import random import re from concurrent.futures import ProcessPoolExecutor, as_completed from functools import partial @@ -239,6 +242,42 @@ def process_manifest_entry(entry: Dict[str, Any], audio_base_dir: Path) -> Tuple return None +def shuffle_jsonl_file(input_path: Path, seed: int = None) -> Path: + """ + Shuffle lines in a JSONL file and write to a shuffled copy. 
+ + Args: + input_path: Path to the original JSONL file + seed: Random seed for reproducible shuffling + + Returns: + Path to the shuffled file + """ + if seed is not None: + random.seed(seed) + + logging.info(f"Reading and shuffling manifest entries from {input_path}") + + # Read all lines into memory + with open(input_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + logging.info(f"Loaded {len(lines)} entries, now shuffling...") + + # Shuffle the lines + random.shuffle(lines) + + # Create output path with "_shuffled" suffix + shuffled_path = input_path.parent / f"{input_path.stem}_shuffled{input_path.suffix}" + + # Write shuffled content + with open(shuffled_path, 'w', encoding='utf-8') as f: + f.writelines(lines) + + logging.info(f"Shuffled manifest written to: {shuffled_path}") + return shuffled_path + + def chunked_iterator(iterable, chunk_size): """Yield successive chunks from iterable.""" _it = iter(iterable) @@ -425,6 +464,17 @@ def main(): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level for the main process and workers.", ) + parser.add_argument( + "--shuffle", + action="store_true", + help="Shuffle the manifest entries before processing.", + ) + parser.add_argument( + "--shuffle-seed", + type=int, + default=None, + help="Random seed for reproducible shuffling (only used if --shuffle is enabled).", + ) args = parser.parse_args() @@ -441,8 +491,15 @@ def main(): target_recordings_dir.mkdir(parents=True, exist_ok=True) context_recordings_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Reading NeMo manifest lazily from: {args.manifest_path}") - manifest_iterable = load_jsonl(args.manifest_path) + # Handle shuffling if requested + if args.shuffle: + logging.info(f"Shuffling manifest entries from: {args.manifest_path}") + shuffled_manifest_path = shuffle_jsonl_file(args.manifest_path, seed=args.shuffle_seed) + manifest_iterable = load_jsonl(shuffled_manifest_path) + logging.info(f"Using shuffled manifest for processing: {shuffled_manifest_path}") + else: + logging.info(f"Reading NeMo manifest lazily from: {args.manifest_path}") + manifest_iterable = load_jsonl(args.manifest_path) logging.info( f"Processing manifest in chunks of {args.processing_chunk_size} using {args.num_jobs} parallel workers..." 
From 799aaadbb550d8023eb48840d0ed2e541e180359 Mon Sep 17 00:00:00 2001 From: Jason Date: Sat, 6 Sep 2025 14:51:18 -0400 Subject: [PATCH 083/113] Always run speech CI on magpie branch regardless of changed files #14605 (#14606) * always run speech tests in branch Signed-off-by: Jason * checkout PR branch instead of HEAD branch Signed-off-by: Jason --------- Signed-off-by: Jason --- .github/actions/test-template/action.yml | 6 ++---- .github/scripts/components_to_run.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/actions/test-template/action.yml b/.github/actions/test-template/action.yml index 3344875312a5..d87a5a447dbc 100644 --- a/.github/actions/test-template/action.yml +++ b/.github/actions/test-template/action.yml @@ -79,11 +79,9 @@ runs: echo "id=$(uuidgen)" >> "$GITHUB_OUTPUT" - name: Checkout NeMo - uses: actions/checkout@v2 - env: - DIR: ${{ github.run_id }} + uses: actions/checkout@v5 with: - repository: NVIDIA/NeMo + ref: ${{ github.event.pull_request.head.ref }} path: ${{ github.run_id }}/${{steps.uuid.outputs.id }}/NeMo - name: Start container diff --git a/.github/scripts/components_to_run.py b/.github/scripts/components_to_run.py index 90ba8c5c10bd..9c9973fdae5e 100644 --- a/.github/scripts/components_to_run.py +++ b/.github/scripts/components_to_run.py @@ -69,7 +69,7 @@ def main(source_sha: str, target_sha: str): # Build dependency graph dependencies = nemo_dependencies.build_dependency_graph(nemo_root) - test_modules: List[str] = [] + test_modules: List[str] = ["speech"] # Always run speech in magpie branch for changed_file in changed_files: if changed_file in dependencies: test_modules.extend(dependencies[changed_file]) From a74a31acbb67db344934c37ec6a09b3394d5613d Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 8 Sep 2025 08:58:31 -0400 Subject: [PATCH 084/113] Whisper output normalization (#14535) Retargetted for latest Dev Branch (#14600) * merge in changes from #14535 Signed-off-by: Jason * Apply isort and black reformatting Signed-off-by: blisc * address lint comments Signed-off-by: Jason --------- Signed-off-by: Jason Signed-off-by: blisc Co-authored-by: blisc --- .../magpietts_preference_optimization.py | 47 +++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index 8c72389b6bd0..281a1d7dfe24 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -16,6 +16,7 @@ import os import random import string +from typing import Optional import librosa import numpy as np @@ -37,6 +38,8 @@ except ImportError: HAVE_TORCHAUDIO = False +from nemo_text_processing.text_normalization.normalize import Normalizer + from nemo.collections.tts.models import MagpieTTSModel @@ -65,6 +68,20 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() + self._normalize_whisper_transcript = cfg.get('normalize_whisper_transcript', True) + if self._normalize_whisper_transcript: + self._normalizer_cache = {} + # Pre-create normalizer for the configured language + lang = cfg.get('pref_set_language', 'en') + self._get_cached_normalizer(lang) + + def _get_cached_normalizer(self, lang_key): + """Get or 
create a cached normalizer for the given language.""" + lang_key = lang_key if lang_key else "en" + if lang_key not in self._normalizer_cache: + logging.info(f"Creating normalizer for language: {lang_key}") + self._normalizer_cache[lang_key] = Normalizer(input_case="cased", lang=lang_key) + return self._normalizer_cache[lang_key] def test_step(self, batch, batch_idx): with torch.no_grad(): @@ -118,12 +135,18 @@ def test_step(self, batch, batch_idx): else: pred_transcripts = [] for audio_path in predicted_audio_paths: + normalizer = ( + self._get_cached_normalizer(self.cfg.pref_set_language) + if self._normalize_whisper_transcript + else None + ) transcript = transcribe_with_whisper( audio_path, self.cfg.pref_set_language, self.whisper_processor, self.whisper_model, self.device, + normalizer, ) pred_transcripts.append(transcript) @@ -136,7 +159,7 @@ def test_step(self, batch, batch_idx): ).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}" logging.warning(f"Exception during ASR transcription: {e}") logging.warning( - f"Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0" + "Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0" ) batch_invalid = True continue # don't break since we want to continue building audio durations list @@ -505,6 +528,11 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): ) self.scale_rewards = self.cfg.get('scale_rewards', True) self.max_decoder_steps = self.cfg.get('max_decoder_steps', 430) + + self._normalize_whisper_transcript = self.cfg.get('normalize_whisper_transcript', True) + if cfg.get('reward_asr_model', "nemo") == "whisper" and self._normalize_whisper_transcript: + self._normalizer_cache = {} + # If the best record in the group is above this threshold, we will not use that group for training # Setting this to 1.0, because we clamp the ASR rewards to be in [0, 1] for OnlinePO self.best_cer_threshold = self.cfg.get('best_cer_threshold', 1.0) @@ -512,6 +540,14 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # Setting this to 1.0, because we clamp the ASR rewards to be in [0, 1] for OnlinePO self.worst_cer_threshold = self.cfg.get('worst_cer_threshold', 1.0) + def _get_cached_normalizer(self, lang_key): + """Get or create a cached normalizer for the given language.""" + lang_key = lang_key if lang_key else "en" + if lang_key not in self._normalizer_cache: + logging.info(f"Creating normalizer for language: {lang_key}") + self._normalizer_cache[lang_key] = Normalizer(input_case="cased", lang=lang_key) + return self._normalizer_cache[lang_key] + def state_dict(self, destination=None, prefix='', keep_vars=False): state_dict = super().state_dict(destination, prefix, keep_vars) keys_substrings_to_exclude = [ @@ -612,8 +648,9 @@ def generate_and_reward( pred_transcripts = [] for item_idx, audio_path in enumerate(predicted_audio_paths): language = batch_repeated['languages'][item_idx] + normalizer = self._get_cached_normalizer(language) if self._normalize_whisper_transcript else None transcript = transcribe_with_whisper( - audio_path, language, self.whisper_processor, self.whisper_model, self.device + audio_path, language, self.whisper_processor, self.whisper_model, self.device, normalizer ) pred_transcripts.append(transcript) pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] @@ -1033,7 +1070,9 @@ def 
get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, return speaker_embeddings -def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper_model, device): +def transcribe_with_whisper( + audio_filepath, language, whisper_processor, whisper_model, device, normalizer: Optional[Normalizer] = None +): speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000) forced_decoder_ids = ( whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None @@ -1044,4 +1083,6 @@ def transcribe_with_whisper(audio_filepath, language, whisper_processor, whisper predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) result = transcription[0] + if normalizer is not None: + result = normalizer.normalize(result) return result From 81d15b11567ec590b23a17cd79191c32a5e93137 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 10 Sep 2025 07:28:47 -0700 Subject: [PATCH 085/113] Remove multilingual sentence piece tokenizer (#14688) * remove multilingual sentence piece tokenizer Signed-off-by: Shehzeen Hussain * remove from infer yaml and readme Signed-off-by: Shehzeen Hussain --------- Signed-off-by: Shehzeen Hussain --- .../conf/magpietts/magpietts_inference_multilingual_v1.yaml | 3 --- examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml | 3 --- scripts/magpietts/README_magpie_po.md | 1 - .../L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh | 1 - 4 files changed, 8 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index 22d2b73fa3a0..d55f2297b596 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -121,9 +121,6 @@ model: tone_prefix: "#" ascii_letter_prefix: "" ascii_letter_case: "upper" - multilingual_sentencepiece: - _target_: AutoTokenizer - pretrained_model: "bert-base-multilingual-uncased" test_ds: dataset: diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index df8dd230b724..cdafae4b192f 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -114,9 +114,6 @@ model: tone_prefix: "#" ascii_letter_prefix: "" ascii_letter_case: "upper" - multilingual_sentencepiece: - _target_: AutoTokenizer - pretrained_model: "bert-base-multilingual-uncased" train_ds: dataset: diff --git a/scripts/magpietts/README_magpie_po.md b/scripts/magpietts/README_magpie_po.md index fee1b575c124..4fee74e3b72f 100644 --- a/scripts/magpietts/README_magpie_po.md +++ b/scripts/magpietts/README_magpie_po.md @@ -195,7 +195,6 @@ python examples/tts/magpietts.py \ batch_size=2 \ +init_from_ptl_ckpt="/mountdir/checkpoints/magpie_checkpoints/shared_char_ipa_epoch285.ckpt" \ +mode="onlinepo_train" \ -~model.text_tokenizers.multilingual_sentencepiece \ +model.text_tokenizers.chartokenizer._target_=AutoTokenizer \ +model.text_tokenizers.chartokenizer.pretrained_model="google/byt5-small" \ max_epochs=20 \ diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh index 622d8b978df5..186cd2148e56 100644 --- 
a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh @@ -14,7 +14,6 @@ coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ --config-name magpietts_multilingual_v1 \ +mode="onlinepo_train" \ - ~model.text_tokenizers.multilingual_sentencepiece \ +model.text_tokenizers.english_chartokenizer._target_=AutoTokenizer \ +model.text_tokenizers.english_chartokenizer.pretrained_model="google/byt5-small" \ +model.text_tokenizers.spanish_chartokenizer._target_=AutoTokenizer \ From 87b946f2d8ed4799bc8e1845478d3a265f4296ba Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 16 Sep 2025 10:02:50 -0700 Subject: [PATCH 086/113] [unittests] fixed bugs for sampling in cer, speaker similarity and validation status. (#14738) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- nemo/collections/common/data/lhotse/sampling.py | 12 ++++++++---- tests/collections/common/test_lhotse_dataloading.py | 4 ---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index e17dd2bb3a2c..0c927058c234 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -18,7 +18,7 @@ from typing import Any, Sequence import numpy as np -from lhotse.cut import Cut +from lhotse.cut import Cut, MonoCut from lhotse.dataset import SamplingConstraint, TokenConstraint from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint from lhotse.utils import ifnone @@ -279,7 +279,7 @@ def __init__(self, keep: str = "pass") -> None: def __call__(self, example) -> bool: if ( - isinstance(example, Cut) + isinstance(example, MonoCut) and example.has_custom("validation_status") and example.validation_status != self.keep ): @@ -298,7 +298,11 @@ def __init__(self, max_cer: float | None) -> None: self.max_cer = ifnone(max_cer, float("inf")) def __call__(self, example) -> bool: - if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("cer"): + if ( + isinstance(example, MonoCut) + and len(example.supervisions) > 0 + and example.supervisions[0].has_custom("cer") + ): return example.supervisions[0].cer <= self.max_cer else: return True @@ -315,7 +319,7 @@ def __init__(self, min_context_speaker_similarity: float | None) -> None: def __call__(self, example) -> bool: if ( - isinstance(example, Cut) + isinstance(example, MonoCut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("context_speaker_similarity") ): diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index e03a73b3c24c..b709b468ab59 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -345,7 +345,6 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path): # exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts -@pytest.mark.pleasefixme def test_dataloader_from_lhotse_cuts_pad_min_duration(cutset_path: Path): config = OmegaConf.create( { @@ -1940,7 +1939,6 @@ def test_multimodal_text_audio_dataloading_randomized_round_robin_strategy( assert torch.is_tensor(ex.mask) -@pytest.mark.pleasefixme def test_dataloader_with_noise_nemo_json(cutset_path: Path, 
nemo_manifest_path: Path): config = OmegaConf.create( { @@ -1970,7 +1968,6 @@ def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: assert -5.0 < cut.tracks[1].snr < 5.0 -@pytest.mark.pleasefixme def test_dataloader_with_noise_lhotse_jsonl(cutset_path: Path): config = OmegaConf.create( { @@ -2000,7 +1997,6 @@ def test_dataloader_with_noise_lhotse_jsonl(cutset_path: Path): assert -5.0 < cut.tracks[1].snr < 5.0 -@pytest.mark.pleasefixme def test_dataloader_with_noise_nemo_tar(cutset_path: Path, nemo_tarred_manifest_path_multi: Path): noise_json, noise_tar = nemo_tarred_manifest_path_multi config = OmegaConf.create( From 7194fe390a3dea9cb500e76c33537d76ba4f912d Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 19 Sep 2025 07:05:28 -0700 Subject: [PATCH 087/113] [magpietts][lhotse] add a script for creating text context manifest for riva speakers and jhh. (#14649) * [magpietts][lhotse] add a script for creating text context manifest for riva speakers and jhh. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * added copyright header Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- .../create_text_context_lhotse_manifest.py | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 scripts/magpietts/create_text_context_lhotse_manifest.py diff --git a/scripts/magpietts/create_text_context_lhotse_manifest.py b/scripts/magpietts/create_text_context_lhotse_manifest.py new file mode 100644 index 000000000000..b229d2bfd2b7 --- /dev/null +++ b/scripts/magpietts/create_text_context_lhotse_manifest.py @@ -0,0 +1,224 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Create Text Context Lhotse Manifest Script + +This script converts MagpieTTS Lhotse manifest files from audio-based context to text-based context +for model training. It processes sharded datasets by extracting speaker and suffix information from +supervision IDs and replacing complex audio context metadata with simplified text context strings. + +The script supports three datasets: +- rivaLindyRodney +- rivaEmmaMeganSeanTom +- jhsdGtc20Amp20Keynote + +For each dataset, it: +1. Finds and validates all shard files in the input directory +2. Processes each shard by replacing audio context with text context +3. Saves the modified shards to a new directory with "_textContext" suffix +4. 
Validates the output by inspecting the last processed cut + +The script expects the input data to be organized as: +``` +./model_release_2505/lhotse_shar/{dataset}/lhotse_shar_shuffle_shardSize256/cuts/ +├── cuts.000000.jsonl.gz +├── cuts.000001.jsonl.gz +└── cuts.000002.jsonl.gz +``` + +Output will be saved to: +``` +./model_release_2505/lhotse_shar/{dataset}/lhotse_shar_shuffle_shardSize256/cuts_textContext/ +├── cuts.000000.jsonl.gz +├── cuts.000001.jsonl.gz +└── cuts.000002.jsonl.gz +``` + +Usage: + python create_text_context_lhotse_manifest.py + + +**BEFORE (Original Audio Context):** +The input cut contains audio context information with references to other audio files: +``` +custom={ + 'emotion': 'happy', + 'context_speaker_similarity': 0.8681941628456116, + 'context_audio_offset': 0.0, + 'context_audio_duration': 4.17, + 'context_audio_text': 'The river bared its bosom, and snorting steamboats challenged the wilderness.', + 'context_recording_id': 'rec-Rodney-44khz-CMU_HAPPY-RODNEY_CMU_HAPPY_000487' +} +``` + +**AFTER (Text Context):** +The output cut contains simplified text context information: +``` +custom={ + 'context_text': 'Speaker and Emotion: | Language:en Dataset:rivaLindyRodney Speaker:Rodney_CMU_HAPPY |', + 'emotion': 'happy' # preserved from original +} +``` +""" + +import glob +import logging +import os +import re +from functools import partial + +from lhotse import CutSet +from rich import print +from tqdm import tqdm + + +def batch_replace_and_write(cut_filepath, new_cut_filepath, dataset_name): + """ + Process a single Lhotse shard file by replacing audio context with text context. + + This function loads a CutSet from a shard file, applies the text context transformation + to each cut, and saves the modified CutSet to a new file. + + Args: + cut_filepath (str): Path to the input shard file (e.g., cuts.000000.jsonl.gz) + new_cut_filepath (str): Path where the modified shard file will be saved + dataset_name (str): Name of the dataset being processed, used to determine + how to parse supervision IDs for speaker information + """ + print(f" Processing {dataset_name}: {cut_filepath} --> {new_cut_filepath}") + cuts = CutSet.from_file(cut_filepath) + cuts_with_validation = cuts.map(partial(replace_audio_context_with_text_context, dataset_name=dataset_name)) + cuts_with_validation.to_file(new_cut_filepath) + + +def replace_audio_context_with_text_context(cut, dataset_name): + """ + Replace audio context information with text context for a single cut. + + This function extracts speaker and speaker suffix information from the supervision ID + and creates a text-based context string. The parsing logic varies by dataset + due to different ID formats. 
+ + Args: + cut: A Lhotse Cut object containing audio and supervision information + dataset_name (str): Name of the dataset, determines parsing logic: + - "rivaLindyRodney": Uses items[4] as speaker suffix + - "rivaEmmaMeganSeanTom": Extracts middle parts of items[4] split by "_" + - "jhsdGtc20Amp20Keynote": Uses items[3] as speaker suffix + + Returns: + cut: The modified Cut object with updated custom context information + + Raises: + ValueError: If dataset_name is not one of the supported datasets + + Example: + For a cut with speaker "Rodney" and supervision ID "sup-rec-Rodney-44khz-CMU_HAPPY-RODNEY_CMU_HAPPY_000452", + this might create context_text: "Speaker and Emotion: | Language:en Dataset:rivaLindyRodney Speaker:Rodney_CMU_HAPPY |" + """ + speaker = cut.supervisions[0].speaker + seg_id = cut.supervisions[0].id + items = seg_id.split("-") + + if dataset_name == "rivaLindyRodney": + speaker_suffix = items[4] + elif dataset_name == "rivaEmmaMeganSeanTom": + speaker_suffix = "_".join(items[4].split("_")[1:-1]) + elif dataset_name == "jhsdGtc20Amp20Keynote": + speaker_suffix = items[3] + else: + raise ValueError(f"Unknown dataset name: {dataset_name}") + + text_context = f"Speaker and Emotion: {speaker.rstrip('| ')}_{speaker_suffix} |" + new_custom = {"context_text": text_context} + + # keep original emotion state if any. + if cut.supervisions[0].has_custom("emotion"): + new_custom.update({"emotion": cut.supervisions[0].emotion}) + + cut.supervisions[0].custom = new_custom + + return cut + + +def find_and_verify_shards(cuts_dir: str): + """ + Find and validate all Lhotse shard files in the specified directory. + + This function searches for shard files matching the pattern "cuts.*.jsonl.gz" + and verifies that the shard indices are contiguous starting from 0. This ensures + that all shards are present and properly numbered for processing. + + Args: + cuts_dir (str): Directory path containing the shard files + + Returns: + list[str]: Sorted list of paths to all shard files + + Raises: + FileNotFoundError: If no shard files are found matching the expected pattern + ValueError: If shard indices are not contiguous or don't start from 0 + + Example: + If cuts_dir contains files: cuts.000000.jsonl.gz, cuts.000001.jsonl.gz, cuts.000002.jsonl.gz + Returns: ['/path/to/cuts.000000.jsonl.gz', '/path/to/cuts.000001.jsonl.gz', '/path/to/cuts.000002.jsonl.gz'] + """ + cuts_shard_pattern = os.path.join(cuts_dir, "cuts.*.jsonl.gz") + all_cuts_shard_paths = sorted(glob.glob(cuts_shard_pattern)) + + if not all_cuts_shard_paths: + msg = f"No input cut shards found matching pattern: {cuts_shard_pattern}. Cannot proceed." + logging.error(msg) + raise FileNotFoundError(msg) + + num_total_shards = len(all_cuts_shard_paths) + + # Verify shard indices are contiguous and start from 0 based on filenames (globally) + first_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[0]).group(1) + last_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[-1]).group(1) + first_idx = int(first_idx_str) + last_idx = int(last_idx_str) + expected_last_idx = num_total_shards - 1 + if first_idx != 0: + raise ValueError(f"Expected first shard index to be 0, but found {first_idx} in {all_cuts_shard_paths[0]}") + if last_idx != expected_last_idx: + raise ValueError( + f"Expected last shard index to be {expected_last_idx}, but found {last_idx} in {all_cuts_shard_paths[-1]}" + ) + logging.info( + f"Verified {num_total_shards} total shard files globally, with indices from {first_idx} to {last_idx}." 
+ ) + return all_cuts_shard_paths + + +if __name__ == "__main__": + datasets = ["rivaLindyRodney", "rivaEmmaMeganSeanTom", "jhsdGtc20Amp20Keynote"] + for dataset in datasets: + cut_dir = f"./model_release_2505/lhotse_shar/{dataset}/lhotse_shar_shuffle_shardSize256/cuts" + all_cuts_shard_paths = find_and_verify_shards(cut_dir) + cut_dir_tc = cut_dir + "_textContext" + os.makedirs(cut_dir_tc, exist_ok=True) + + for cut_filepath in tqdm(all_cuts_shard_paths, total=len(all_cuts_shard_paths)): + cut_basename = os.path.basename(cut_filepath) + cut_filepath_tc = os.path.join(cut_dir_tc, cut_basename) + batch_replace_and_write(cut_filepath, cut_filepath_tc, dataset_name=dataset) + + # validate + cuts = CutSet.from_file(cut_filepath_tc) + cuts_list = list() + for cut in cuts: + cuts_list.append(cut) + print(cuts_list[-1]) From d0cfb273ebac881edfc2706d14b7ada5de25ce72 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Mon, 22 Sep 2025 09:14:16 -0700 Subject: [PATCH 088/113] End detection updates and Text context remapping during training (#14569) * add text context remapping during training Signed-off-by: Paarth Neekhara * EOS detection method made customizable Signed-off-by: Paarth Neekhara * connect with infer eval Signed-off-by: Paarth Neekhara * Apply isort and black reformatting Signed-off-by: paarthneekhara * addressed comments and added doc Signed-off-by: Paarth Neekhara * Apply copolit suggestion and update docstring --------- Signed-off-by: Paarth Neekhara Signed-off-by: paarthneekhara Co-authored-by: paarthneekhara Co-authored-by: Jason --- .../tts/conf/magpietts/magpietts_dc_en.yaml | 3 + examples/tts/conf/magpietts/magpietts_en.yaml | 3 + .../magpietts/magpietts_inference_en.yaml | 3 + .../magpietts_inference_multilingual_v1.yaml | 3 + .../magpietts/magpietts_lhotse_dc_en.yaml | 3 + .../magpietts_lhotse_dc_en_tiny.yaml | 3 + .../magpietts/magpietts_multilingual_v1.yaml | 3 + .../tts/data/text_to_speech_dataset.py | 15 ++- .../tts/data/text_to_speech_dataset_lhotse.py | 13 ++- nemo/collections/tts/models/magpietts.py | 94 +++++++++++++++---- .../tts/modules/magpietts_modules.py | 42 +++++++++ scripts/magpietts/infer_and_evaluate.py | 22 +++++ 12 files changed, 184 insertions(+), 23 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index 9703c9d3b58b..76d6fdde052a 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -53,6 +53,9 @@ model: local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index cb5842cdf834..5d45bbb23764 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -55,6 +55,9 @@ model: local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. 
+ text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_inference_en.yaml index 3c252139a7e5..993868ea6959 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_en.yaml @@ -62,6 +62,9 @@ model: local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml index d55f2297b596..6e8bbf47a7d1 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml @@ -62,6 +62,9 @@ model: local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index 337de93ddba3..b3060c5b8384 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -42,6 +42,9 @@ model: local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml index 0d26dfdb7996..b1dd69b17723 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml @@ -42,6 +42,9 @@ model: local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
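  # Illustration only (not shipped with these configs): a remapping file such as
  #   {"Emma_neutral": "Emma", "Emma_conversational": "Emma"}
  # collapses several context labels into one. With text_context_remapping_prob: 0.3,
  # roughly 30% of matching training examples would see the remapped context "Emma";
  # all other splits and inference keep the exact context_text from the data.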
+ text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index cdafae4b192f..0e4cf263d469 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -55,6 +55,9 @@ model: local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 + text_tokenizers: # Add more languages for multi-lingual TTS english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 6e1e20725c89..a64160011606 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -355,6 +355,8 @@ class MagpieTTSDataset(TextToSpeechDataset): pad_context_text_to_max_duration: Whether to pad context text to max context audio frames. context_duration_min: Minimum duration of context audio in seconds. context_duration_max: Maximum duration of context audio in seconds. + text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. + text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. """ def __init__( @@ -383,6 +385,8 @@ def __init__( pad_context_text_to_max_duration: bool = False, context_duration_min: float = 3.0, context_duration_max: float = 10.0, + text_context_remapping: Dict[str, str] = None, + text_context_remapping_prob: float = 0.0, ): super().__init__( dataset_meta=dataset_meta, @@ -417,6 +421,8 @@ def __init__( self.pad_context_text_to_max_duration = pad_context_text_to_max_duration self.context_duration_min = context_duration_min self.context_duration_max = context_duration_max + self.text_context_remapping = text_context_remapping + self.text_context_remapping_prob = text_context_remapping_prob def get_num_audio_samples_to_slice(self, duration, sample_rate): num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) @@ -578,9 +584,12 @@ def __getitem__(self, index): if self.use_text_conditioning_tokenizer: if 'context_text' in data.manifest_entry: - context_tokens = self.text_tokenizer.encode( - data.manifest_entry['context_text'], self.text_conditioning_tokenizer_name - ) + context_text = data.manifest_entry['context_text'] + if self.text_context_remapping is not None and context_text in self.text_context_remapping: + if self.dataset_type == 'train' and random.random() < self.text_context_remapping_prob: + # Only remap during training. Give the exact text context during inference. 
+ context_text = self.text_context_remapping[context_text] + context_tokens = self.text_tokenizer.encode(context_text, self.text_conditioning_tokenizer_name) example['has_text_context'] = True else: context_tokens = self.text_tokenizer.encode("[NO TEXT CONTEXT]", self.text_conditioning_tokenizer_name) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index d427c17fefb4..6090c4b164f0 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -126,6 +126,8 @@ class MagpieTTSLhotseDataset(torch.utils.data.Dataset): tokenizer_config (Optional[DictConfig]): Configuration for the text tokenizers. Used for lazy initialization within workers. Must be provided if tokenizers are not set externally. Defaults to None. + text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. + text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. """ def __init__( @@ -148,6 +150,8 @@ def __init__( use_text_conditioning_tokenizer: bool = False, text_conditioning_tokenizer_name: str = None, tokenizer_config: DictConfig = None, + text_context_remapping: Dict[str, str] = None, + text_context_remapping_prob: float = 0.0, ): super().__init__() self.sample_rate = sample_rate @@ -172,6 +176,8 @@ def __init__( self.context_duration_max = context_duration_max self.tokenizer_config = tokenizer_config self.text_tokenizer = None + self.text_context_remapping = text_context_remapping + self.text_context_remapping_prob = text_context_remapping_prob def get_num_audio_samples_to_slice(self, duration, sample_rate): num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) @@ -369,8 +375,13 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: if self.use_text_conditioning_tokenizer: if cut.supervisions[0].has_custom("context_text"): + context_text = cut.supervisions[0].context_text + if self.text_context_remapping is not None and context_text in self.text_context_remapping: + if self.dataset_type == 'train' and random.random() < self.text_context_remapping_prob: + # Only remap during training. Give the exact text context during inference. + context_text = self.text_context_remapping[context_text] context_text_tokens = self.text_tokenizer.encode( - cut.supervisions[0].context_text, tokenizer_name=self.text_conditioning_tokenizer_name + context_text, tokenizer_name=self.text_conditioning_tokenizer_name ) has_text_context = True else: diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index fb9a060e6561..973e75fb9f33 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
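For intuition, a minimal self-contained sketch (not NeMo code) of the probability-gated remapping performed by both datasets above; the speaker labels, probability value, and helper name are illustrative assumptions:

import random

remap = {"Emma_neutral": "Emma", "Emma_conversational": "Emma"}
remap_prob = 0.3

def resolve_context(context_text: str, is_training: bool) -> str:
    # Remap only during training, and only with probability `remap_prob`;
    # evaluation and inference always keep the exact manifest context.
    if is_training and context_text in remap and random.random() < remap_prob:
        return remap[context_text]
    return context_text

random.seed(0)
hits = sum(resolve_context("Emma_neutral", is_training=True) == "Emma" for _ in range(10_000))
print(hits / 10_000)                                        # close to 0.3
print(resolve_context("Emma_neutral", is_training=False))   # always "Emma_neutral"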
+import json import os import random import time from functools import partial -from typing import List, Optional +from typing import List, Union import numpy as np import soundfile as sf @@ -37,6 +38,7 @@ from nemo.collections.tts.modules.aligner import AlignmentEncoder from nemo.collections.tts.modules.magpietts_modules import ( CharAwareSubwordEncoder, + EOSDetectionMethod, LocalTransformerType, SpecialAudioToken, cosine_schedule, @@ -176,6 +178,20 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce'] self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) + # Below args (text_context_remapping_json, text_context_remapping_prob) are + # for combining multiple context_texts into a single one during training. + # Eg. if we want to treat Emma_neutral and Emma_conversational as one speaker, + # we can create an override dict {'Emma_neutral' : 'Emma', 'Emma_conversational' : 'Emma'} + # This dict is saved in a json file given by cfg.model.text_context_remapping_json + # If we want to preserve both behaviours i.e (Emma_neutral, Emma_conversational) and just (Emma) + # we can do this mapping with a probability during training, as specified by text_context_remapping_prob + self.text_context_remapping = None + text_context_remapping_json = cfg.get('text_context_remapping_json', None) + self.text_context_remapping_prob = cfg.get('text_context_remapping_prob', 0.0) + if text_context_remapping_json is not None: + with open(text_context_remapping_json, 'r') as f: + self.text_context_remapping = json.load(f) + super().__init__(cfg=cfg, trainer=trainer) if self.legacy_text_conditioning: @@ -1997,22 +2013,57 @@ def get_inference_attention_plots( return cross_attention_maps, headwise_cross_attention_maps - def find_eos_frame_index(self, codes) -> Optional[int]: + def find_eos_frame_index(self, codes, eos_detection_method) -> Union[int, float]: """ Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack that contains an EOS token across any codebook, or `None` if no EOS is found. 
Args: codes: (num_codebooks, frame_stacking_factor) Returns: - index (within the frame stack) of the first frame with EOS, or `None` if no EOS is found + index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found """ eos_mask = codes == self.audio_eos_id # (codebooks, frame_stacking_factor) - eos_per_frame = eos_mask.any(dim=0) # (frame_stacking_factor,) - True if any codebook has EOS in this frame + detection_type = EOSDetectionMethod.detection_type(eos_detection_method) + if detection_type == "any": + eos_per_frame = eos_mask.any( + dim=0 + ) # (frame_stacking_factor,) - True if any codebook has EOS in this frame + elif detection_type == "all": + eos_per_frame = eos_mask.all( + dim=0 + ) # (frame_stacking_factor,) - True if all codebooks have EOS in this frame + elif detection_type == "zero_cb": + eos_per_frame = eos_mask[:1, :].any( + dim=0 + ) # (frame_stacking_factor,) - True if zeroth codebook has EOS in this frame + else: + raise ValueError(f"Invalid EOS detection method: {eos_detection_method}") # find first frame with EOS if eos_per_frame.any(): # return index of the first frame with EOS return eos_per_frame.nonzero()[0].item() - return None + return float('inf') + + def detect_eos(self, audio_codes_multinomial, audio_codes_argmax, eos_detection_method) -> Union[int, float]: + """ + Detects EOS in the predicted codes. Returns the index of the first frame within the frame stack + that triggers EOS detection, or `float('inf')` if no EOS is found. + Args: + audio_codes_multinomial: (num_codebooks, frame_stacking_factor) - Multinomial samples + audio_codes_argmax: (num_codebooks, frame_stacking_factor) - Argmax samples + eos_detection_method: EOS detection method + Returns: + index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found + """ + sampling_type = EOSDetectionMethod.sampling_type(eos_detection_method) + if sampling_type == "argmax": + return self.find_eos_frame_index(audio_codes_argmax, eos_detection_method) + elif sampling_type == "argmax_or_multinomial": + argmax_eos_frame = self.find_eos_frame_index(audio_codes_argmax, eos_detection_method) + multinomial_eos_frame = self.find_eos_frame_index(audio_codes_multinomial, eos_detection_method) + return min(argmax_eos_frame, multinomial_eos_frame) + else: + raise ValueError(f"Invalid EOS detection method: {eos_detection_method}") def infer_batch( self, @@ -2037,7 +2088,10 @@ def infer_batch( maskgit_fixed_schedule=None, maskgit_dynamic_cfg_scale=False, maskgit_sampling_type=None, + ignore_finished_sentence_tracking=False, + eos_detection_method="argmax_or_multinomial_any", ): + eos_detection_method = EOSDetectionMethod(eos_detection_method) with torch.no_grad(): start_time = time.time() self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference) @@ -2194,10 +2248,14 @@ def infer_batch( batch_size=batch_size, ) - finished_items = { - k: v for k, v in finished_texts_counter.items() if v >= 20 - } # Items that have been close to the end for atleast 20 timesteps - unfinished_items = {k: v for k, v in unfinished_texts.items() if v} + if ignore_finished_sentence_tracking: + finished_items = {} + unfinished_items = {} + else: + finished_items = { + k: v for k, v in finished_texts_counter.items() if v >= 20 + } # Items that have been close to the end for atleast 20 timesteps + unfinished_items = {k: v for k, v in unfinished_texts.items() if v} all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) if 
use_local_transformer_for_inference: @@ -2250,17 +2308,11 @@ def infer_batch( for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: - # check for EOS (including within the frame stack) - eos_frame_multinomial = self.find_eos_frame_index(audio_codes_next[item_idx]) - eos_frame_argmax = self.find_eos_frame_index(all_codes_next_argmax[item_idx]) - eos_frame_multinomial = ( - eos_frame_multinomial if eos_frame_multinomial is not None else float('inf') + end_frame_index = self.detect_eos( + audio_codes_next[item_idx], all_codes_next_argmax[item_idx], eos_detection_method ) - eos_frame_argmax = eos_frame_argmax if eos_frame_argmax is not None else float('inf') - # pick minimum of the two - frame_index = min(eos_frame_multinomial, eos_frame_argmax) - if frame_index != float('inf'): - global_index = idx * self.frame_stacking_factor + frame_index + if end_frame_index != float('inf'): + global_index = idx * self.frame_stacking_factor + end_frame_index end_indices[item_idx] = global_index print(f"End detected for item {item_idx} at decoder timestep: {idx}") @@ -2421,6 +2473,8 @@ def get_dataset(self, dataset_cfg, dataset_type): pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, context_duration_min=self.cfg.context_duration_min, context_duration_max=self.cfg.context_duration_max, + text_context_remapping=self.text_context_remapping, + text_context_remapping_prob=self.text_context_remapping_prob, ) dataset.load_16khz_audio = self.model_type == 'single_encoder_sv_tts' dataset.tokenizer_config = ( @@ -2450,6 +2504,8 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, tokenizer_config=self.cfg.text_tokenizers, + text_context_remapping=self.text_context_remapping, + text_context_remapping_prob=self.text_context_remapping_prob, ) data_loader = get_lhotse_dataloader_from_config( config=dataset_cfg.dataset, diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index a731a5ea97ed..e1c6786ed776 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -36,6 +36,48 @@ class LocalTransformerType(PrettyStrEnum): MASKGIT = "maskgit" +class EOSDetectionMethod(PrettyStrEnum): + """ + Enum for the EOS detection method to use in the MagpieTTS model. + These strings are the values allowed in the YAML config file. 
+ """ + + ARGMAX_ANY = "argmax_any" + ARGMAX_OR_MULTINOMIAL_ANY = "argmax_or_multinomial_any" + ARGMAX_ALL = "argmax_all" + ARGMAX_OR_MULTINOMIAL_ALL = "argmax_or_multinomial_all" + ARGMAX_ZERO_CB = "argmax_zero_cb" + ARGMAX_OR_MULTINOMIAL_ZERO_CB = "argmax_or_multinomial_zero_cb" + + @staticmethod + def detection_type(detection_method: EOSDetectionMethod): + if detection_method in [EOSDetectionMethod.ARGMAX_ANY, EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ANY]: + return "any" + elif detection_method in [EOSDetectionMethod.ARGMAX_ALL, EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ALL]: + return "all" + elif detection_method in [EOSDetectionMethod.ARGMAX_ZERO_CB, EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ZERO_CB]: + return "zero_cb" + else: + raise ValueError(f"Invalid EOS detection method: {detection_method}") + + @staticmethod + def sampling_type(detection_method: EOSDetectionMethod): + if detection_method in [ + EOSDetectionMethod.ARGMAX_ANY, + EOSDetectionMethod.ARGMAX_ALL, + EOSDetectionMethod.ARGMAX_ZERO_CB, + ]: + return "argmax" + elif detection_method in [ + EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ANY, + EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ALL, + EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ZERO_CB, + ]: + return "argmax_or_multinomial" + else: + raise ValueError(f"Invalid EOS detection method: {detection_method}") + + class SpecialAudioToken(Enum): """ Enum for the special tokens to use in the MagpieTTS model. diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index b84f8e1a0a71..a22c178bcec4 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -293,6 +293,8 @@ def run_inference( log_exp_name=False, compute_fcd=False, violin_plot_metrics=['cer', 'pred_context_ssim'], + eos_detection_method=None, + ignore_finished_sentence_tracking=False, ): # Load model if hparams_file is not None and checkpoint_file is not None: @@ -359,6 +361,8 @@ def run_inference( f"LT_{use_local_transformer}_" f"MaskGit_{maskgit_n_steps}_{maskgit_sampling_type}_{''.join([str(l) for l in maskgit_fixed_schedule]) if maskgit_fixed_schedule is not None else 'None'}_" f"SV_{sv_model}" + f"EOS_{eos_detection_method}" + f"IgnoreFST_{ignore_finished_sentence_tracking}" ) dataset_meta_info = evalset_config.dataset_meta_info @@ -488,6 +492,8 @@ def run_inference( maskgit_noise_scale=maskgit_noise_scale, maskgit_fixed_schedule=maskgit_fixed_schedule, maskgit_sampling_type=maskgit_sampling_type, + ignore_finished_sentence_tracking=ignore_finished_sentence_tracking, + eos_detection_method=eos_detection_method, ) all_rtf_metrics.append(rtf_metrics) @@ -650,6 +656,19 @@ def main(): parser.add_argument('--start_prior_after_n_audio_steps', type=int, default=0) parser.add_argument('--topk', type=int, default=80) parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument( + '--eos_detection_method', + type=str, + default="argmax_or_multinomial_any", + choices=[ + "argmax_any", + "argmax_or_multinomial_any", + "argmax_all", + "argmax_or_multinomial_all", + "argmax_zero_cb", + "argmax_or_multinomial_zero_cb", + ], + ) # Parameters for evaluation parser.add_argument('--sv_model', type=str, default="titanet") # titanet, wavlm parser.add_argument( @@ -659,6 +678,7 @@ def main(): parser.add_argument('--confidence_level', type=float, default=0.95) parser.add_argument('--legacy_codebooks', action='store_true') parser.add_argument('--legacy_text_conditioning', action='store_true') + 
parser.add_argument('--ignore_finished_sentence_tracking', action='store_true') parser.add_argument('--clean_up_disk', action='store_true') parser.add_argument('--cer_target', type=float, default=None) parser.add_argument('--ssim_target', type=float, default=None) @@ -722,6 +742,8 @@ def main(): log_exp_name=args.log_exp_name, compute_fcd=compute_fcd, violin_plot_metrics=args.violin_plot_metrics, + eos_detection_method=args.eos_detection_method, + ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking, ) # Mode 1: Run inference from provided hparams and checkpoint files From d7fee58387fc8fae892f95b1196671b8e7692c1c Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 26 Sep 2025 20:47:00 -0400 Subject: [PATCH 089/113] Checkout magpie branch instead of main (#14813) * update checkout in cicd Signed-off-by: Jason * use our template Signed-off-by: Jason * change ref to be local rather than global Signed-off-by: Jason * change ref to be local rather than global Signed-off-by: Jason * revert to defaul ref of merge commit Signed-off-by: Jason * undo container change Signed-off-by: Jason * undo container change Signed-off-by: Jason * another attempt for local action Signed-off-by: Jason * another attempt for local action Signed-off-by: Jason * another attempt for local action Signed-off-by: Jason * disable L2 tests Signed-off-by: Jason --------- Signed-off-by: Jason --- .github/actions/test-template/action.yml | 11 +++-- .github/workflows/cicd-main-automodel.yml | 4 +- .github/workflows/cicd-main-nemo2.yml | 4 +- .github/workflows/cicd-main-speech.yml | 8 +--- .github/workflows/cicd-main-unit-tests.yml | 28 ++++--------- .github/workflows/cicd-main.yml | 49 ++++++++++------------ 6 files changed, 39 insertions(+), 65 deletions(-) diff --git a/.github/actions/test-template/action.yml b/.github/actions/test-template/action.yml index d87a5a447dbc..e2ac1852dec6 100644 --- a/.github/actions/test-template/action.yml +++ b/.github/actions/test-template/action.yml @@ -52,11 +52,11 @@ inputs: runs: using: "composite" steps: - - name: Noop - shell: bash - run: | - chmod -R u+rwX ${{ github.run_id }} - echo "noop" + # - name: Noop + # shell: bash + # run: | + # chmod -R u+rwX ${{ github.run_id }} + # echo "noop" - name: Docker system cleanup shell: bash @@ -81,7 +81,6 @@ runs: - name: Checkout NeMo uses: actions/checkout@v5 with: - ref: ${{ github.event.pull_request.head.ref }} path: ${{ github.run_id }}/${{steps.uuid.outputs.id }}/NeMo - name: Start container diff --git a/.github/workflows/cicd-main-automodel.yml b/.github/workflows/cicd-main-automodel.yml index 89b8addd0261..2eaebe9b2b27 100644 --- a/.github/workflows/cicd-main-automodel.yml +++ b/.github/workflows/cicd-main-automodel.yml @@ -46,10 +46,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} diff --git a/.github/workflows/cicd-main-nemo2.yml b/.github/workflows/cicd-main-nemo2.yml index 092ac96c655f..a056410520e7 100644 --- a/.github/workflows/cicd-main-nemo2.yml +++ b/.github/workflows/cicd-main-nemo2.yml @@ -289,10 +289,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} diff --git 
a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index 6d274271647f..ae1d70f56315 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -65,10 +65,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -196,10 +194,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} diff --git a/.github/workflows/cicd-main-unit-tests.yml b/.github/workflows/cicd-main-unit-tests.yml index 2db945f54ba5..ebe596fa90c5 100644 --- a/.github/workflows/cicd-main-unit-tests.yml +++ b/.github/workflows/cicd-main-unit-tests.yml @@ -35,10 +35,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -61,10 +59,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -88,10 +84,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -114,10 +108,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -146,10 +138,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -172,10 +162,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -199,10 +187,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: ${{ matrix.script }} diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index ce61beb1456c..a1f4ebb9ac35 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -183,9 +183,8 @@ jobs: echo "id=$(uuidgen)" >> "$GITHUB_OUTPUT" - name: Checkout NeMo - uses: actions/checkout@v2 + uses: actions/checkout@v5 with: - repository: NVIDIA/NeMo path: ${{ github.run_id }}/${{steps.uuid.outputs.id }}/NeMo - name: Run some checks @@ -227,11 +226,9 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - path: ${{ 
github.run_id }} - name: main - uses: NVIDIA/NeMo/.github/actions/test-template@main + uses: ./.github/actions/test-template with: runner: ${{ runner.name }} script: L0_Setup_Test_Data_And_Models @@ -296,27 +293,27 @@ jobs: with: test_to_run: ${{ needs.pre-flight.outputs.test_to_run }} - cicd-main-nemo2: - needs: [pre-flight, cicd-test-container-build, cicd-main-unit-tests] - uses: ./.github/workflows/cicd-main-nemo2.yml - if: | - ( - needs.pre-flight.outputs.test_to_run != '[]' - && ( - contains(fromJson(needs.pre-flight.outputs.components_to_run), 'nemo2') - || needs.pre-flight.outputs.components_to_run == '["all"]' - ) - ) - && ( - success() - || ( - needs.cicd-wait-in-queue.result == 'skipped' - && needs.pre-flight.outputs.is_ci_workload == 'true' - ) - ) - && !cancelled() - with: - test_to_run: ${{ needs.pre-flight.outputs.test_to_run }} + # cicd-main-nemo2: disabled for magpie dev banch - no L0s and skipping L2s + # needs: [pre-flight, cicd-test-container-build, cicd-main-unit-tests] + # uses: ./.github/workflows/cicd-main-nemo2.yml + # if: | + # ( + # needs.pre-flight.outputs.test_to_run != '[]' + # && ( + # contains(fromJson(needs.pre-flight.outputs.components_to_run), 'nemo2') + # || needs.pre-flight.outputs.components_to_run == '["all"]' + # ) + # ) + # && ( + # success() + # || ( + # needs.cicd-wait-in-queue.result == 'skipped' + # && needs.pre-flight.outputs.is_ci_workload == 'true' + # ) + # ) + # && !cancelled() + # with: + # test_to_run: ${{ needs.pre-flight.outputs.test_to_run }} Nemo_CICD_Test: needs: From f3878d7b86e3f129f211d9c19d66e482a45f6fde Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Sat, 27 Sep 2025 04:10:39 -0700 Subject: [PATCH 090/113] Don't allow EOS until 4 frames have been generated (#14761) * Don't allow EOS until 4 frames have been generated The number of frames is configurable via a parameter to infer_batch(). This is a workaround to the observation that when CFG is on generation sometimes terminates after zero tokens. That appears to be an artifact of CFG, since the EOS logits (both cond and uncon) are not particularly large before CFG but become the maximum logit after it, with CFG amplifying the differences between the two. Signed-off-by: Fejgin, Roy * Formatting Signed-off-by: Fejgin, Roy * Command line option to set minimum number of frames to generate Signed-off-by: Fejgin, Roy * formatting Signed-off-by: Fejgin, Roy * Show extreme values in violin plots (to aid in debugging rare issues) Signed-off-by: Fejgin, Roy * Fix merge issues Signed-off-by: Fejgin, Roy * More merge fixes Signed-off-by: Fejgin, Roy * Remove temporary changes in infer_and_evaluate.py Signed-off-by: Fejgin, Roy * Comments Signed-off-by: Fejgin, Roy * Comments Signed-off-by: Fejgin, Roy * Fix typo Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 37 ++++++++++++++++++++---- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 973e75fb9f33..c719d5f9b1ab 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -734,7 +734,7 @@ def code_to_str(code): output_str += c logging.debug(output_str) - def clear_forbidden_logits(self, logits): + def clear_forbidden_logits(self, logits, forbid_audio_eos=False): """ Sets logits of forbidden tokens to `-inf` so they will never be sampled. Specifically, we forbid sampling of all special tokens except AUDIO_EOS. 
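A minimal sketch (not the NeMo implementation) of the guard this commit describes: the EOS logit is simply masked to -inf until a minimum number of audio frames has been generated. The helper name, tensor shapes, and token ids below are illustrative assumptions.

import torch

def forbid_early_eos(logits: torch.Tensor, step: int, frame_stacking_factor: int,
                     min_generated_frames: int, eos_id: int) -> torch.Tensor:
    # logits: (batch, codebooks, tokens_per_codebook) for the current decoder step.
    # While fewer than `min_generated_frames` audio frames exist, mask the EOS logit so it
    # can never be sampled; this blocks the CFG-amplified zero-length terminations noted above.
    if step * frame_stacking_factor < min_generated_frames:
        logits = logits.clone()
        logits[:, :, eos_id] = float('-inf')
    return logits

# At decoder step 0 with frame stacking 2, EOS is still forbidden because 0 * 2 < 4.
masked = forbid_early_eos(torch.zeros(1, 8, 2051), step=0, frame_stacking_factor=2,
                          min_generated_frames=4, eos_id=2050)
print(torch.isinf(masked[0, 0, 2050]))  # tensor(True)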
@@ -742,7 +742,9 @@ def clear_forbidden_logits(self, logits): logits: (B, C, num_audio_tokens_per_codebook) """ logits[ - :, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=False) + :, + :, + SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos), ] = float('-inf') return logits @@ -760,6 +762,7 @@ def local_transformer_sample_maskgit( fixed_schedule=None, dynamic_cfg_scale=False, sampling_type=None, + forbid_audio_eos=False, ): """ Sample codes for one timestep from the local transformer using MaskGit. @@ -868,7 +871,7 @@ def local_transformer_sample_maskgit( logits[:actual_batch_size] = cfg_logits # Disallow generation of special tokens (except audio EOS which is handled separately) - logits = self.clear_forbidden_logits(logits) + logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos) # handle unfinished and finished items for item_idx in unfinished_items: @@ -937,6 +940,7 @@ def local_transformer_sample_autoregressive( use_cfg=False, cfg_scale=1.0, use_kv_cache=True, + forbid_audio_eos=False, ): # dec_output: (B, E) self.local_transformer.reset_cache(use_cache=use_kv_cache) @@ -964,7 +968,9 @@ def local_transformer_sample_autoregressive( codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 - codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1) + codebook_logits = self.clear_forbidden_logits( + codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos + ).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -998,7 +1004,13 @@ def local_transformer_sample_autoregressive( return all_preds def sample_codes_from_logits( - self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={} + self, + all_code_logits_t, + temperature=0.7, + topk=80, + unfinished_items={}, + finished_items={}, + forbid_audio_eos=False, ): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep all_preds = [[] for _ in range(self.frame_stacking_factor)] @@ -1013,7 +1025,9 @@ def sample_codes_from_logits( for item_idx in finished_items: codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 - codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1) + codebook_logits = self.clear_forbidden_logits( + codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos + ).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -2090,6 +2104,9 @@ def infer_batch( maskgit_sampling_type=None, ignore_finished_sentence_tracking=False, eos_detection_method="argmax_or_multinomial_any", + # Setting this greater than 0 prevents rare cases of first-frame termination. Any number greater between 1 and 4 should work, but 4 + # lines up with the codec's minimum frame requirement. 
+ min_generated_frames=4, ): eos_detection_method = EOSDetectionMethod(eos_detection_method) with torch.no_grad(): @@ -2257,6 +2274,10 @@ def infer_batch( } # Items that have been close to the end for atleast 20 timesteps unfinished_items = {k: v for k, v in unfinished_texts.items() if v} + # Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor) + # This guards against rare cases of termination right at the start of generation. + forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames + all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) if use_local_transformer_for_inference: if self.local_transformer_type == LocalTransformerType.AR: @@ -2270,6 +2291,7 @@ def infer_batch( use_cfg=use_cfg, cfg_scale=cfg_scale, use_kv_cache=use_LT_kv_cache, + forbid_audio_eos=forbid_audio_eos, ) elif self.local_transformer_type == LocalTransformerType.MASKGIT: audio_codes_next = self.local_transformer_sample_maskgit( @@ -2285,6 +2307,7 @@ def infer_batch( fixed_schedule=maskgit_fixed_schedule, dynamic_cfg_scale=maskgit_dynamic_cfg_scale, sampling_type=maskgit_sampling_type, + forbid_audio_eos=forbid_audio_eos, ) else: raise ValueError( @@ -2298,12 +2321,14 @@ def infer_batch( topk=topk, unfinished_items=unfinished_items, finished_items=finished_items, + forbid_audio_eos=forbid_audio_eos, ) # (B, num_codebooks, frame_stacking_factor) all_codes_next_argmax = self.sample_codes_from_logits( all_code_logits_t, temperature=0.01, unfinished_items=unfinished_items, finished_items=finished_items, + forbid_audio_eos=forbid_audio_eos, ) # (B, num_codebooks, frame_stacking_factor) for item_idx in range(all_codes_next_argmax.size(0)): From 55d5a0119f5c1072017d24ac226ab527842e4c78 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Tue, 30 Sep 2025 14:14:59 -0700 Subject: [PATCH 091/113] Magpietts 2508 Attention mask bug fix (#14836) * bug fix in attention mask Signed-off-by: Paarth Neekhara * Apply isort and black reformatting Signed-off-by: paarthneekhara * handle None as well Signed-off-by: Paarth Neekhara * Apply isort and black reformatting Signed-off-by: paarthneekhara * Added tests and handled masking in convolutional layer Signed-off-by: Paarth Neekhara * Apply isort and black reformatting Signed-off-by: paarthneekhara --------- Signed-off-by: Paarth Neekhara Signed-off-by: paarthneekhara Co-authored-by: paarthneekhara --- .../tts/modules/transformer_2501.py | 24 +++++-- .../tts/modules/test_transformer_2501.py | 68 +++++++++++++++++-- 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index e730ebfe12f1..ea5a08a0a696 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -82,11 +82,15 @@ def __init__( bias=bias, ) - def forward(self, signal): + def forward(self, signal, signal_mask): + # signal: (B, C, T) + # signal_mask: (B, T) + signal = signal * signal_mask.unsqueeze(1) if self.is_causal: # TODO: maybe replace with identify rather than keep conditional if in forward signal = F.pad(signal, self.causal_padding) conv_signal = self.conv(signal) + conv_signal = conv_signal * signal_mask.unsqueeze(1) return conv_signal @@ -126,12 +130,13 @@ def __init__( self.o_net = ConvolutionLayer(d_ffn, d_model, bias=bias, kernel_size=kernel_size, is_causal=is_causal) self.dropout = 
torch.nn.Dropout(p_dropout) - def forward(self, x): + def forward(self, x, x_mask): """ x (B, T, C) + x_mask (B, T) """ - x = self.non_linearity(self.proj(x.transpose(1, 2))) - x = self.dropout(self.o_net(x).transpose(1, 2)) + x = self.non_linearity(self.proj(x.transpose(1, 2), x_mask)) + x = self.dropout(self.o_net(x, x_mask).transpose(1, 2)) return x @@ -350,7 +355,14 @@ def compute_qkv_and_mask( v = torch.cat([self.cache['self_v'], v], dim=1) self.cache['self_k'] = k self.cache['self_v'] = v - mask = query_mask[:, None, :, None] if query_mask is not None else None + + mask = None + if query_mask is not None: + # query_mask is a boolean mask of shape (B, T) + # mask should be of shape (B, 1, T, T) where mask[:,0,i,:] == mask[:,0,:,i] == query_mask + mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2) + mask = mask.unsqueeze(1) + return q, k, v, mask @@ -551,7 +563,7 @@ def forward( x = x + x_res # mlp final projection - x = x + self.pos_ff(self.norm_pos_ff(x)) + x = x + self.pos_ff(self.norm_pos_ff(x), x_mask) x = x * x_mask.unsqueeze(-1) return { diff --git a/tests/collections/tts/modules/test_transformer_2501.py b/tests/collections/tts/modules/test_transformer_2501.py index b7f486028aea..606ce12bf324 100644 --- a/tests/collections/tts/modules/test_transformer_2501.py +++ b/tests/collections/tts/modules/test_transformer_2501.py @@ -37,6 +37,9 @@ def set_seed(seed): random.seed(seed) +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths + + @pytest.mark.unit class TestConvolutionLayer: @classmethod @@ -53,6 +56,7 @@ def setup_class(cls): [-1.0317, 1.6818, 1.4257, -0.5003, -1.7254, 0.8830, -0.4541, -0.4631, -0.0986, 0.5083], [-0.3231, -1.0899, 0.5774, 0.1661, 0.9620, -2.3307, -0.6158, -0.3663, 1.2469, -1.0208]]] ) + cls.input_mask = torch.ones(1, cls.input_tensor.shape[2]) # fmt:on def test_non_causal_forward(self): @@ -68,7 +72,7 @@ def test_non_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -96,7 +100,7 @@ def test_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -133,6 +137,7 @@ def setup_class(cls): [-0.1543, 0.3365, 1.7475], [-0.1753, 0.4115, 0.0772]]] ) + cls.input_mask = torch.ones(1, cls.input_tensor.shape[1]) # fmt:on def test_causal_forward(self): @@ -142,7 +147,7 @@ def test_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -168,7 +173,7 @@ def test_non_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -795,3 +800,58 @@ def test_forward_causal_self_attn_and_has_xattn(self): expected_output["attn_probabilities"][i]["cross_attn_probabilities"][0], atol=1e-4, ) + + +@pytest.mark.unit +class TestTransformerBatchedInference: + @classmethod + def setup_class(cls): + cls.n_layers = 3 + cls.d_model = 4 + cls.d_ffn = 16 + cls.sa_n_heads = 2 + cls.p_dropout = 0.0 + cls.p_dropout_out = 0.0 + cls.max_length_causal_mask = 10 + cls.short_length = 4 + cls.long_length = 10 + + def test_forward(self): + set_seed(0) + query_tensor1 = torch.randn(1, self.long_length, 
self.d_model) + query_tensor2 = torch.randn(1, self.short_length, self.d_model) + + padding_tensor = torch.randn(1, self.long_length - self.short_length, self.d_model) + query_tensor2_padded = torch.cat([query_tensor2, padding_tensor], dim=1) + lengths = torch.tensor([self.long_length, self.short_length]) + mask_batched = get_mask_from_lengths(lengths) + + query_batched = torch.cat([query_tensor1, query_tensor2_padded], dim=0) + + mask_bs1_1 = torch.ones(1, self.long_length) + mask_bs1_2 = torch.ones(1, self.short_length) + + for is_causal in [True, False]: + for kernel_size in [1, 3]: + model = Transformer( + n_layers=self.n_layers, + d_model=self.d_model, + d_ffn=self.d_ffn, + sa_n_heads=self.sa_n_heads, + kernel_size=kernel_size, + p_dropout=self.p_dropout, + p_dropout_out=self.p_dropout_out, + is_causal=is_causal, + max_length_causal_mask=self.max_length_causal_mask, + ) + + output_batched = model(query_batched, mask_batched) + output_bs1_1 = model(query_tensor1, mask_bs1_1) + output_bs1_2 = model(query_tensor2, mask_bs1_2) + + assert torch.allclose( + output_batched['output'][0][: self.long_length, :], output_bs1_1['output'], atol=1e-4 + ) + assert torch.allclose( + output_batched['output'][1][: self.short_length, :], output_bs1_2['output'], atol=1e-4 + ) From 6b5d25b7558d8633727b418b0fa4ee2721c85e31 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:28:36 -0700 Subject: [PATCH 092/113] [magpie][context audio] add speaker items limit to compute similarity matrix to avoid OOM. (#14833) * [step2] add speaker items limit to compute similarity matrix to avoid OOM of CPU ram. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * defined a constant MASKED_SIMILAIRY_VALUE to cover the magic number -2.0 for better maintainability. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- ...extend_nemo_manifest_with_context_audio.py | 100 ++++++++++++++---- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/scripts/magpietts/extend_nemo_manifest_with_context_audio.py b/scripts/magpietts/extend_nemo_manifest_with_context_audio.py index 426fcac99d16..98c59eb379fd 100644 --- a/scripts/magpietts/extend_nemo_manifest_with_context_audio.py +++ b/scripts/magpietts/extend_nemo_manifest_with_context_audio.py @@ -16,6 +16,7 @@ import json import logging import os +import random import re import time from collections import defaultdict @@ -35,6 +36,10 @@ logger = logging.getLogger(__name__) +# Constant for masking identical items in similarity matrix +# Set below valid cosine similarity range [-1, 1] to ensure masked items are never selected +MASKED_SIMILARITY_VALUE = -2.0 + """ Usage: python scripts/magpietts/extend_manifest_with_context_audio.py @@ -48,10 +53,16 @@ --num-workers 4 --context-min-duration 3.0 --context-min-ssim 0.6 + --max-speaker-items 20000 # optional, prevents OOM for large speakers This script distributes speakers across DDP ranks. Each rank processes its assigned speakers and writes a partial manifest. Rank 0 then merges these into a final manifest. +The --max-speaker-items parameter limits the size of the context pool per speaker to prevent OOM +when computing similarity matrices. If a speaker has more items than this limit, a random +sample will be used as the context pool, but all items will still be processed to find +their best context from this pool. 
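+ As a rough illustration (numbers are examples only): for a speaker with N = 50,000 utterances and
+ --max-speaker-items 20000, the float32 similarity matrix costs about 50,000 x 20,000 x 4 bytes, i.e.
+ roughly 4 GB, instead of about 50,000 x 50,000 x 4 bytes (roughly 10 GB) for the full N x N matrix,
+ while every utterance still has its context selected from the 20,000-item pool.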
+ Input manifest example entry: { "audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", @@ -183,6 +194,7 @@ def __init__( context_min_ssim: float, speaker_expected_counts_map: dict, initial_assigned_count: int, + max_speaker_items: int = None, ): super().__init__() self.sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( @@ -197,6 +209,7 @@ def __init__( self.context_min_ssim = context_min_ssim self.speaker_expected_counts = speaker_expected_counts_map self.initial_assigned_count = initial_assigned_count + self.max_speaker_items = max_speaker_items # Per-rank attributes self.output_file_path = None @@ -216,6 +229,10 @@ def setup(self, stage: str): self.output_dir.mkdir(parents=True, exist_ok=True) self.output_manifest_file = open(self.output_file_path, "w", encoding="utf-8") logger.info(f"Writing partial manifest to: `{self.output_file_path}`") + if self.max_speaker_items: + logger.info(f"Max speaker items limit set to: {self.max_speaker_items}") + else: + logger.info("No max speaker items limit set (potential OOM risk for very large speakers)") logger.debug(f"Expected speaker counts for model: {self.speaker_expected_counts}") def forward(self, batch): @@ -295,46 +312,80 @@ def _process_and_flush_speakers_local(self): self.total_accumulated_items -= len(speaker_items) self.processed_speakers_set.add(speaker_id) - # NOTE: Potential OOM (Out Of Memory) risk if a single speaker has an extremely large - # number of segments (e.g., tens of thousands). The N x N similarity matrix calculated below - # (where N = len(speaker_items)) can consume significant CPU RAM. - # For example, 50,000 segments for one speaker could lead to a float32 similarity matrix - # requiring approximately 10 GB of RAM. Consider this if processing datasets with - # speakers having a very high number of utterances. - embeddings = torch.stack([item['embedding'] for item in speaker_items]) - embeddings_norm = torch.nn.functional.normalize(embeddings, p=2, dim=1) - similarity_matrix = torch.matmul(embeddings_norm, embeddings_norm.transpose(0, 1)) - similarity_matrix.fill_diagonal_(-2.0) # cosine similarity range is [-1, 1] + # Apply speaker size limit to prevent OOM while processing all items + all_items_to_process = speaker_items # We want to process ALL items + + # Create context pool with original indices for easy identification + if self.max_speaker_items and len(speaker_items) > self.max_speaker_items: + logger.warning( + f"Speaker {speaker_id} has {len(speaker_items)} items, exceeding max limit of {self.max_speaker_items}. " + f"Using random sample of {self.max_speaker_items} items as context pool, but processing all {len(speaker_items)} items." + ) + # Randomly sample with original indices preserved + random.seed(12345) # For reproducibility + indexed_items = [(idx, item) for idx, item in enumerate(speaker_items)] + sampled_indexed_items = random.sample(indexed_items, self.max_speaker_items) + context_pool_items = [item for _, item in sampled_indexed_items] + context_pool_original_indices = [idx for idx, _ in sampled_indexed_items] + else: + # Use all items as context pool + context_pool_items = speaker_items + context_pool_original_indices = list(range(len(speaker_items))) + + # NOTE: Now we compute similarities between ALL items and the context pool. + # This limits the similarity matrix to N×M instead of N×N where M <= max_speaker_items. + # Memory usage: N×M×4 bytes instead of N×N×4 bytes. 
+ all_embeddings = torch.stack([item['embedding'] for item in all_items_to_process]) + context_embeddings = torch.stack([item['embedding'] for item in context_pool_items]) + + all_embeddings_norm = torch.nn.functional.normalize(all_embeddings, p=2, dim=1) + context_embeddings_norm = torch.nn.functional.normalize(context_embeddings, p=2, dim=1) + + # Compute N×M similarity matrix: each row is similarities for one item against all context candidates + similarity_matrix = torch.matmul(all_embeddings_norm, context_embeddings_norm.transpose(0, 1)) + + # Mask positions where items are identical (same item appearing in both N and M sets) + # Using original indices as identifiers. This prevents an item from being selected as its own context. + # Create a mapping from original indices to context pool positions + original_index_to_context_position = {} + for context_pos, original_idx in enumerate(context_pool_original_indices): + original_index_to_context_position[original_idx] = context_pos + + # Mask similarities for identical items + for n_idx in range(len(all_items_to_process)): + if n_idx in original_index_to_context_position: + context_pos = original_index_to_context_position[n_idx] + similarity_matrix[n_idx, context_pos] = MASKED_SIMILARITY_VALUE # Sort all similarities for each item to iterate through candidates - # best_similarities_tensor will contain sorted similarities for each row (original item) - # best_indices_tensor will contain original indices of these sorted items + # sorted_similarities_tensor will contain sorted similarities for each row (original item) + # sorted_indices_tensor will contain indices in the context_pool sorted_similarities_tensor, sorted_indices_tensor = torch.sort(similarity_matrix, dim=1, descending=True) - record_preparation_start_time = time.time() num_records_written_for_speaker = 0 # Initialize a counter for items discarded for this specific speaker num_discarded_for_this_speaker_no_context = 0 - for i, current_item_data in enumerate(speaker_items): + for i, current_item_data in enumerate(all_items_to_process): output_record = current_item_data['metadata'].copy() write_this_record = False - # Iterate through potential candidates, sorted by similarity + # Iterate through potential candidates from context pool, sorted by similarity for candidate_rank in range(sorted_indices_tensor.size(1)): candidate_ssim = sorted_similarities_tensor[i, candidate_rank].item() - original_candidate_idx = sorted_indices_tensor[i, candidate_rank].item() + context_pool_idx = sorted_indices_tensor[i, candidate_rank].item() - # Skip if candidate is the item itself (safeguard) - if original_candidate_idx == i: - continue + # if ANY candidate has similarity ≤ MASKED_SIMILARITY_VALUE, all subsequent ones will be ≤ MASKED_SIMILARITY_VALUE + # since similarities are sorted in descending order, we can break early + if candidate_ssim <= MASKED_SIMILARITY_VALUE: + break - # If SSIM is below threshold, stop searching for this item (since candidates are sorted) + # If SSIM is below threshold, stop searching for this item if candidate_ssim < self.context_min_ssim: break # Check duration if SSIM is acceptable - best_meta_dict = speaker_items[original_candidate_idx]['metadata'] + best_meta_dict = context_pool_items[context_pool_idx]['metadata'] candidate_duration = best_meta_dict["duration"] if candidate_duration >= self.context_min_duration: @@ -563,6 +614,12 @@ def main(): parser.add_argument( "--context-min-ssim", type=float, default=0.6, help="Minimum cosine similarity for a context audio 
segment." ) + parser.add_argument( + "--max-speaker-items", + type=int, + default=None, + help="Maximum size of context pool per speaker to prevent OOM. If a speaker has more items, a random sample will be used as context pool, but all items will still be processed. Default: None (no limit, potential OOM risk).", + ) parser.add_argument("--devices", type=int, default=-1) parser.add_argument("--num-nodes", type=int, default=1) parser.add_argument("--batch-size", type=int, default=16) @@ -837,6 +894,7 @@ def main(): context_min_ssim=args.context_min_ssim, speaker_expected_counts_map=my_speaker_expected_counts, initial_assigned_count=len(assigned_records_for_this_rank), + max_speaker_items=args.max_speaker_items, ) logger.info( f"Starting prediction with {len(assigned_records_for_this_rank)} records ({len(my_speaker_expected_counts)} unique speakers for this rank according to counts)." From 0b485374994a64bf490acb2969a1bd14d04b4f2f Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Wed, 1 Oct 2025 12:48:19 -0700 Subject: [PATCH 093/113] Add new spectral codec definition (#14794) * [TTS] Add new spectral codec definition Signed-off-by: Ryan * Add codec MMD loss definitions Signed-off-by: Ryan * Apply isort and black reformatting Signed-off-by: rlangman --------- Signed-off-by: Ryan Signed-off-by: rlangman Signed-off-by: Jason Co-authored-by: Jason Co-authored-by: rlangman --- nemo/collections/common/parts/utils.py | 6 +- .../tts/losses/audio_codec_loss.py | 173 ++++++++++ nemo/collections/tts/models/audio_codec.py | 35 +- nemo/collections/tts/models/magpietts.py | 46 ++- .../tts/modules/audio_codec_modules.py | 301 ++++++++++++++++-- 5 files changed, 529 insertions(+), 32 deletions(-) diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py index 6479dc315ba2..fce56eeb4820 100644 --- a/nemo/collections/common/parts/utils.py +++ b/nemo/collections/common/parts/utils.py @@ -159,12 +159,16 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): class ClampActivation(nn.Module): - def __init__(self, min_value: float = -1.0, max_value: float = 1.0): + def __init__(self, min_value: float = -1.0, max_value: float = 1.0, clamp_training: bool = True): super().__init__() self.min_value = min_value self.max_value = max_value + self.clamp_training = clamp_training def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.training and not self.clamp_training: + return input + return torch.clamp(input, min=self.min_value, max=self.max_value) diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index e87f4a959443..b970261e30c5 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -512,3 +512,176 @@ def forward(self, disc_scores_real, disc_scores_gen): loss /= len(disc_scores_real) return loss + + +class MMDLoss(Loss): + """ + Maximum mean discrepancy (MMD) loss, as defined in https://arxiv.org/abs/2406.02315 + + Args: + kernel_radii: List of radii for Gaussian kernels + loss_scale: Constant to multiply loss by + """ + + def __init__(self, kernel_radii=(0.1, 1, 5, 10, 20, 50), loss_scale=1.0): + super().__init__() + self.kernel_radii = kernel_radii + self.loss_scale = loss_scale + + @staticmethod + def _exp_kernel(dxx, r): + return torch.exp((-0.5 / r) * dxx).sum() + + @staticmethod + def _shuffle_codebooks(x): + B, C, _ = x.size() + x_shuffled = torch.zeros_like(x) + for c in range(C): + batch_perm = torch.randperm(B, 
device=x.device) + x_shuffled[:, c, :] = x[batch_perm, c, :] + return x_shuffled + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'C', 'D'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + B, C, D = inputs.size() + + x = inputs + x_mean = x.mean(dim=(0,), keepdim=True) + x_stdev = torch.sqrt(x.var(dim=(0,), keepdim=True) + 1e-8) + x = (x - x_mean) / x_stdev + y = self._shuffle_codebooks(x) + + # [B, C * D] + x = x.reshape([B, C * D]) + y = y.reshape([B, C * D]) + + # [B, B] + xx = torch.mm(x, x.t()) + yy = torch.mm(y, y.t()) + zz = torch.mm(x, y.t()) + + rx = xx.diag().unsqueeze(0).expand_as(xx) + ry = yy.diag().unsqueeze(0).expand_as(yy) + + dxx = rx.t() + rx - 2.0 * xx + dyy = ry.t() + ry - 2.0 * yy + dxy = rx.t() + ry - 2.0 * zz + + loss = 0.0 + coeff = -2.0 / B**2 + denom = B * (B - 1) + for r in self.kernel_radii: + loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dxx, r) - B) / denom + loss += coeff * torch.utils.checkpoint.checkpoint(self._exp_kernel, dxy, r) + loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dyy, r) - B) / denom + + loss = loss.clamp(min=0) + loss = self.loss_scale * loss + return loss + + +class MMDCodebookLoss(Loss): + """ + MMD loss which incentivizes independence between codebooks within each timestep. + + Args: + num_codebooks: Number of codebooks. + codebook_dim: Dimension of a single codebook code. + loss_fn: MMDLoss instance. + """ + + def __init__(self, num_codebooks, codebook_dim, loss_fn): + super().__init__() + self.num_codebooks = num_codebooks + self.codebook_dim = codebook_dim + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + B, D, T = inputs.size() + + # [B, C, D / C, T] + x = inputs.reshape(B, self.num_codebooks, self.codebook_dim, T) + # [B*T, C, D / C] + x = rearrange(x, 'B C D T -> (B T) C D') + loss = self.loss_fn(inputs=x) + return loss + + +class MMDEmbeddingLoss(Loss): + """ + MMD loss which incentivizes independence between embedding values within each timestep. + + Args: + loss_fn: MMDLoss instance. + """ + + def __init__(self, loss_fn): + super().__init__() + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + # [B*T, 1, D] + x = rearrange(inputs, 'B D T -> (B T) D 1') + loss = self.loss_fn(inputs=x) + return loss + + +class MMDTimeLoss(Loss): + """ + MMD loss which incentivizes independence between different timesteps. + + Args: + loss_fn: MMDLoss instance. 
+ """ + + def __init__(self, loss_fn): + super().__init__() + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + x = rearrange(inputs, 'B D T -> B T D') + loss = self.loss_fn(inputs=x) + return loss diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 785fa62210d9..c700e4ba3b80 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -143,6 +143,22 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.gen_loss_fn = instantiate(cfg.generator_loss) self.disc_loss_fn = instantiate(cfg.discriminator_loss) + self.mmd_loss_start_epoch = cfg.get("mmd_loss_start_epoch", 0) + + if "mmd_loss" in cfg: + self.mmd_loss_fn = instantiate(cfg.mmd_loss) + self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0) + else: + self.mmd_loss_fn = None + self.mmd_loss_scale = None + + if "mmd_time_loss" in cfg: + self.mmd_time_loss_fn = instantiate(cfg.mmd_time_loss) + self.mmd_time_loss_scale = cfg.get("mmd_time_loss_scale", 1.0) + else: + self.mmd_time_loss_fn = None + self.mmd_time_loss_scale = None + feature_loss_type = cfg.get("feature_loss_type", "relative") if feature_loss_type == "relative": self.feature_loss_fn = RelativeFeatureMatchingLoss() @@ -497,7 +513,7 @@ def _process_batch(self, batch): encoded = encoded.to(self.dtype) # make sure vector quantizer output is in the model dtype audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len) - return audio, audio_len, audio_gen, commit_loss + return audio, audio_len, audio_gen, commit_loss, encoded @property def disc_update_prob(self) -> float: @@ -514,7 +530,7 @@ def should_update_disc(self, batch_idx) -> bool: def training_step(self, batch, batch_idx): optim_gen, optim_disc = self.optimizers() - audio, audio_len, audio_gen, commit_loss = self._process_batch(batch) + audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch) metrics = { "global_step": self.global_step, @@ -578,6 +594,19 @@ def training_step(self, batch, batch_idx): metrics["g_loss_commit"] = commit_loss generator_losses.append(self.commit_loss_scale * commit_loss) + if self.mmd_loss_scale: + loss_mmd = self.mmd_loss_fn(inputs=codes) + metrics["g_loss_mmd"] = loss_mmd + + if self.current_epoch >= self.mmd_loss_start_epoch: + generator_losses.append(self.mmd_loss_scale * loss_mmd) + + if self.mmd_time_loss_scale: + loss_mmd_time = self.mmd_time_loss_fn(inputs=codes) + metrics["g_loss_mmd_time"] = loss_mmd_time + if self.current_epoch >= self.mmd_loss_start_epoch: + generator_losses.append(self.mmd_time_loss_scale * loss_mmd_time) + # compute embeddings for speaker consistency loss if self.use_scl_loss: # concate generated and GT waveforms @@ -623,7 +652,7 @@ def on_train_epoch_end(self): self.update_lr("epoch") def validation_step(self, batch, batch_idx): - audio, audio_len, audio_gen, _ = self._process_batch(batch) + audio, audio_len, audio_gen, _, _ = self._process_batch(batch) loss_mel_l1, loss_mel_l2 = self.mel_loss_fn( audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index c719d5f9b1ab..e2d9f077424a 100644 --- a/nemo/collections/tts/models/magpietts.py +++ 
b/nemo/collections/tts/models/magpietts.py @@ -36,6 +36,7 @@ from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.modules.aligner import AlignmentEncoder +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter from nemo.collections.tts.modules.magpietts_modules import ( CharAwareSubwordEncoder, EOSDetectionMethod, @@ -95,17 +96,32 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # load codec codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) + self.sample_rate = codec_model.sample_rate + self.codec_model_samples_per_frame = codec_model.samples_per_frame # del codec discriminator to free memory del codec_model.discriminator - # Set up codebook configuration - self.num_audio_codebooks = codec_model.num_codebooks - self.codec_model_samples_per_frame = codec_model.samples_per_frame + # When using FSQ tokens, the codebook structure can be changed at any time. + # An FSQ definition can be provided in `vector_quantizer` config to train with a codebook structure + # that is different than in the audio codec checkpoint. + vector_quantizer = cfg.get('vector_quantizer') + if vector_quantizer is not None: + vector_quantizer = instantiate(vector_quantizer) + self.num_audio_codebooks = vector_quantizer.num_codebooks + self.codebook_size = vector_quantizer.codebook_size + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vector_quantizer, + ) + else: + self.num_audio_codebooks = codec_model.num_codebooks + self.codebook_size = codec_model.codebook_size + codec_converter = None + # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code. 
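Before the special-token bookkeeping below, a sketch of how the optional `vector_quantizer` override above is wired up at runtime (illustrative; `codec_model`, `new_vq`, `codes`, and `codes_len` are placeholder names, only the constructor and method signatures come from this patch):

from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter

converter = VectorQuantizerIndexConverter(
    vector_quantizer_original=codec_model.vector_quantizer,  # layout baked into the codec checkpoint
    vector_quantizer_new=new_vq,                             # layout requested via cfg.vector_quantizer
)
# Codec token indices are remapped into the new codebook structure for training,
# and mapped back before handing codes to the codec decoder.
new_codes = converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len)
orig_codes = converter.convert_new_to_original(audio_tokens=new_codes, audio_lens=codes_len)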
- num_audio_tokens = codec_model.codebook_size - get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=num_audio_tokens) + get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size) self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS)) self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS)) self.context_audio_bos_id = cfg.get( @@ -116,7 +132,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): ) self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN)) self.num_all_tokens_per_codebook = cfg.get( - 'forced_num_all_tokens_per_codebook', num_audio_tokens + len(SpecialAudioToken) + 'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken) ) self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) @@ -201,6 +217,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # This needs to happen after super().__init__() self._codec_model = codec_model self._codec_model.freeze() # Lightning does requires_grad = False and self.eval() + self._codec_converter = codec_converter audio_embeddings = [] for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): @@ -450,6 +467,8 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): self._codec_model.eval() with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32): codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) + if self._codec_converter is not None: + codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len) # Add a timestep to begining and end of codes tensor bos_tensor = torch.full( (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device @@ -478,6 +497,10 @@ def codes_to_audio(self, codes, codes_len): codes_copy[codes == self.audio_bos_id] = 0 # zero is the padding token codes_copy[codes == self.audio_eos_id] = 0 # Pass the modified integer token IDs + if self._codec_converter is not None: + codes_copy = self._codec_converter.convert_new_to_original( + audio_tokens=codes_copy, audio_lens=codes_len + ) audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len) # audio: (B, T) # audio_len: (B,) @@ -744,7 +767,7 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False): logits[ :, :, - SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos), + SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos), ] = float('-inf') return logits @@ -1276,6 +1299,10 @@ def prepare_context_tensors(self, batch): if 'context_audio_codes' in batch: context_audio_codes = batch['context_audio_codes'] context_audio_codes_lens = batch['context_audio_codes_lens'] + if self._codec_converter is not None: + context_audio_codes = self._codec_converter.convert_original_to_new( + audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens + ).long() else: context_audio_codes, context_audio_codes_lens = self.audio_to_codes( batch['context_audio'], batch['context_audio_lens'], audio_type='context' @@ -1498,6 +1525,10 @@ def process_batch(self, batch, mode="train"): else: audio_codes = batch['audio_codes'] audio_codes_lens = batch['audio_codes_lens'] + if self._codec_converter: + audio_codes = self._codec_converter.convert_original_to_new( + 
audio_tokens=audio_codes, audio_lens=audio_codes_lens + ).long() if self.frame_stacking_factor > 1: # repeat the BOS token to frame_stacking_factor times. This is necessary since at inference # we need to start autoregressive generation from a full stack indicating BOS. @@ -2326,6 +2357,7 @@ def infer_batch( all_codes_next_argmax = self.sample_codes_from_logits( all_code_logits_t, temperature=0.01, + topk=1, unfinished_items=unfinished_items, finished_items=finished_items, forbid_audio_eos=forbid_audio_eos, diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index ad8d27f0178d..e84b64cc1ff1 100755 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -31,9 +31,9 @@ from nemo.core.neural_types.elements import ( AudioSignal, EncodedRepresentation, - Index, LengthsType, MelSpectrogramType, + TokenIndex, VoidType, ) from nemo.core.neural_types.neural_type import NeuralType @@ -512,6 +512,7 @@ def __init__( kernel_size: int, stride: int = 1, groups: int = None, + activation: Optional[str] = None, trim_right_ratio: int = 1, bias=True, ): @@ -524,6 +525,11 @@ def __init__( self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = nn.Identity() + kernel_size = self.conv.kernel_size[0] stride = self.conv.stride[0] padding_total = kernel_size - stride @@ -552,6 +558,7 @@ def forward(self, inputs, input_len): # unpad end = hidden_states.shape[-1] - self.padding_right hidden_states = hidden_states[..., self.padding_left : end] + hidden_states = self.activation(hidden_states) # mask hidden_states = mask_sequence_tensor(hidden_states, input_len) return hidden_states @@ -568,6 +575,7 @@ def __init__( stride: int = 1, dilation: int = 1, groups: int = 1, + activation: Optional[str] = None, pad_mode: str = "zeros", extra_pad_mode: str = "constant", bias: bool = True, @@ -592,6 +600,10 @@ def __init__( bias=bias, padding_mode=pad_mode, ) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = nn.Identity() kernel_size = self.conv.kernel_size[0] stride = torch.tensor(self.conv.stride[0], dtype=torch.int64) @@ -649,6 +661,7 @@ def forward(self, inputs, input_len): # Left padding for causal hidden_states = self._pad1d(inputs, (self.padding_total, extra_padding), mode=self.extra_pad_mode) hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) # mask output hidden_states = mask_sequence_tensor(hidden_states, input_len) @@ -711,7 +724,15 @@ def forward(self, inputs, input_len): class ConvTranspose1dNorm(NeuralModule): - def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, groups: int = 1): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + activation: Optional[str] = None, + ): super().__init__() padding, output_padding = get_up_sample_padding(kernel_size, stride) conv = nn.ConvTranspose1d( @@ -726,6 +747,11 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride ) self.conv = nn.utils.parametrizations.weight_norm(conv) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + 
self.activation = nn.Identity() + @property def input_types(self): return { @@ -745,6 +771,7 @@ def remove_weight_norm(self): @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) + out = self.activation(out) out = mask_sequence_tensor(out, input_len) return out @@ -1147,7 +1174,7 @@ def input_types(self): def output_types(self): return { "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), } @typecheck() @@ -1160,7 +1187,7 @@ def forward(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch. "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + output_types={"indices": NeuralType(('D', 'B', 'T'), TokenIndex())}, ) @abstractmethod def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @@ -1168,7 +1195,7 @@ def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @typecheck( input_types={ - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ @@ -1285,7 +1312,7 @@ def compress(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tenso "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"codes": NeuralType(('B', 'D', 'T'), Index())}, + output_types={"codes": NeuralType(('B', 'D', 'T'), TokenIndex())}, ) def inputs_to_codes(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: # apply compression @@ -1347,7 +1374,7 @@ def forward( "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), }, - output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + output_types={"indices": NeuralType(('D', 'B', 'T'), TokenIndex())}, ) def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: """Convert a continuous code vector to a single index.""" @@ -1356,7 +1383,7 @@ def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) @typecheck( input_types={ - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), }, output_types={ @@ -1460,7 +1487,7 @@ def forward(self, inputs, input_len): "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + output_types={"indices": NeuralType(('D', 'B', 'T'), TokenIndex())}, ) def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: """Input is split into groups, each group is encoded separately, then the results are concatenated.""" @@ -1478,7 +1505,7 @@ def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @typecheck( input_types={ - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ @@ -1505,7 +1532,7 @@ def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ - 
"indices": NeuralType(('B', 'D', 'T'), Index()), + "indices": NeuralType(('B', 'D', 'T'), TokenIndex()), }, ) def codes_to_indices(self, codes: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @@ -1610,7 +1637,10 @@ class ResidualBlockV2(NeuralModule): channels: Input dimension. filters: Number of channels in the residual convolutions. kernel_size: Kernel size of the residual convolutions. - activation: Name of activation function. + activation: Activation to apply in between residual convolutions. + is_causal: Whether to use causal convolutions. + pad_mode: Type of padding to use for conv1d layers. + See https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html """ def __init__( @@ -1619,13 +1649,28 @@ def __init__( filters: int, kernel_size: int = 3, activation: str = "lrelu", + is_causal: bool = False, + pad_mode: str = "reflect", ): super(ResidualBlockV2, self).__init__() - self.input_conv = Conv1dNorm( - in_channels=channels, out_channels=filters, kernel_size=kernel_size, activation=activation - ) - self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) + if not is_causal: + self.input_conv = Conv1dNorm( + in_channels=channels, + out_channels=filters, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, + ) + self.skip_conv = Conv1dNorm( + in_channels=filters, out_channels=channels, kernel_size=kernel_size, pad_mode=pad_mode + ) + else: + self.input_conv = CausalConv1dNorm( + in_channels=channels, out_channels=filters, kernel_size=kernel_size, activation=activation + ) + self.skip_conv = CausalConv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) + self.output_activation = CodecActivation(activation=activation, channels=channels) def remove_weight_norm(self): @@ -1646,6 +1691,7 @@ def forward(self, inputs, input_len): res = self.skip_conv(inputs=res, input_len=input_len) out = inputs + res out = self.output_activation(out) + out = mask_sequence_tensor(out, lengths=input_len) return out @@ -2585,6 +2631,7 @@ def __init__( kernel_size: int, activation: str, down_sample_rate: int, + pad_mode: str, ): super(STFTResidualBlock, self).__init__() down_sample_kernel_size = down_sample_rate * 2 + 1 @@ -2596,6 +2643,7 @@ def __init__( kernel_size=down_sample_kernel_size, stride=self.down_sample_rate, activation=activation, + pad_mode=pad_mode, ) n_fft, hop_length, win_length = resolution @@ -2605,7 +2653,7 @@ def __init__( self.spec_act = CodecActivation(activation=activation, channels=filters) self.res_block = ResidualBlockV2( - channels=filters, filters=filters, kernel_size=kernel_size, activation=activation + channels=filters, filters=filters, kernel_size=kernel_size, activation=activation, pad_mode=pad_mode ) def remove_weight_norm(self): @@ -2661,6 +2709,7 @@ def __init__( kernel_size: int, activation: str, down_sample_rate: int, + pad_mode: str, ): super(DownSampleResidualBlock, self).__init__() down_sample_kernel_size = down_sample_rate * 2 + 1 @@ -2672,6 +2721,7 @@ def __init__( kernel_size=down_sample_kernel_size, stride=self.down_sample_rate, activation=activation, + pad_mode=pad_mode, ) self.res_block = ResidualBlockV2( channels=filters, filters=filters, kernel_size=kernel_size, activation=activation @@ -2702,7 +2752,8 @@ def forward(self, inputs, input_len): class MultiResolutionSTFTEncoder(NeuralModule): """ - Interface for computing log magnitude STFT features. 
+ Encoder which computes log magnitude STFT features at several time resolutions and encodes them into a low + frame-rate representation. Args: out_dim: Dimension of encoder output embedding. @@ -2714,6 +2765,10 @@ class MultiResolutionSTFTEncoder(NeuralModule): The total down sample rate of the encoder will be 2**(len(resolutions)) * product(down_sample_rate_list) kernel_size: Kernel size to use in all convolutions. activation: Name of activation function. + resample_rates: Optional tuple of two integers. If provided, input audio will be resampled from sampling rate + resample_rates[0] to sampling rate resample_rates[1]. + pad_mode: Type of padding to use for conv1d layers. + See https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html """ def __init__( @@ -2725,20 +2780,41 @@ def __init__( down_sample_rate_list: Tuple[int] = (), kernel_size: int = 3, activation: str = "lrelu", + resample_rates: Tuple[int] = (), + pad_mode: str = "replicate", ): super(MultiResolutionSTFTEncoder, self).__init__() assert len(resolutions) >= 1 assert len(resolutions) == len(resolution_filter_list) + if resample_rates: + if not HAVE_TORCHAUDIO: + raise ValueError("Must install torchaudio for resampling.") + + input_sr, encoder_sr = resample_rates + self.resample = torchaudio.transforms.Resample(input_sr, encoder_sr) + self.resample_length_modifier = encoder_sr / input_sr + else: + self.resample = torch.nn.Identity() + self.resample_length_modifier = 1.0 + n_fft, hop_length, win_length = resolutions[0] input_filters = resolution_filter_list[0] input_dim = n_fft // 2 + 1 self.pre_spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) self.pre_conv = Conv1dNorm( - in_channels=input_dim, out_channels=input_filters, kernel_size=kernel_size, activation=activation + in_channels=input_dim, + out_channels=input_filters, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, ) self.pre_res_block = ResidualBlockV2( - channels=input_filters, filters=input_filters, kernel_size=kernel_size, activation=activation + channels=input_filters, + filters=input_filters, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, ) input_dim = input_filters self.stft_blocks = nn.ModuleList([]) @@ -2750,6 +2826,7 @@ def __init__( filters=filters, kernel_size=kernel_size, activation=activation, + pad_mode=pad_mode, ) self.stft_blocks.append(stft_block) input_dim = filters @@ -2765,11 +2842,17 @@ def __init__( down_sample_rate=down_sample_rate, kernel_size=kernel_size, activation=activation, + pad_mode=pad_mode, ) self.down_sample_blocks.append(down_sample_block) input_dim = filters - self.post_conv = Conv1dNorm(in_channels=input_dim, out_channels=out_dim, kernel_size=kernel_size) + self.post_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=out_dim, + kernel_size=kernel_size, + pad_mode=pad_mode, + ) def remove_weight_norm(self): self.encoder.remove_weight_norm() @@ -2790,6 +2873,9 @@ def output_types(self): @typecheck() def forward(self, audio, audio_len): + audio = self.resample(audio) + audio_len = torch.round(self.resample_length_modifier * audio_len).int() + encoded, encoded_len = self.pre_spec_processor(audio=audio, audio_len=audio_len) encoded = self.pre_conv(inputs=encoded, input_len=encoded_len) encoded = self.pre_res_block(inputs=encoded, input_len=encoded_len) @@ -2863,3 +2949,176 @@ def convert_new_to_original(self, audio_tokens, audio_lens): codes=audio_codes, input_len=audio_lens ) return audio_tokens_original + + +class 
ResNetDecoder(NeuralModule): + """ + A residual decoder designed for low-latency. Most processing is done at a low frame-rate (e.g. 50 FPS), while + minimizing the size of the network which upsamples to the final waveform. + + Args: + input_dim: Dimension of decoder input. + input_filters: Size of the first CNN layer applied to the decoder input. + pre_up_sample_rates: Up sample rates to apply prior to main decoder network. + pre_up_sample_filters: Size of residual blocks in first up sampling blocks. + n_hidden_layers: Number of residual blocks in the main decoder network, which processes the latent space at + low frame-rate. + hidden_filters: Size of each rsidual block in the main decoder network. + resblock_up_sample_rates: Up sample rates to apply after main decoder network. + resblock_up_sample_filters: Size of residual blocks in final up sampling blocks. + resblock_up_sample_kernel_size: Kernel size to use in final up sampling blocks. + kernel_size: Kernel size to use in all other CNN layers. + activation: Name of activation to use in residual blocks. + is_causal: Whether to make the decoder causal. + pad_mode: Type of padding to use for conv1d layers. + See https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + """ + + def __init__( + self, + input_dim: int, + input_filters: int, + pre_up_sample_rates: List[int], + pre_up_sample_filters: List[int], + n_hidden_layers: int, + hidden_filters: int, + resblock_up_sample_rates: List[int], + resblock_up_sample_filters: List[int], + resblock_up_sample_kernel_size: int = 7, + kernel_size: int = 3, + activation: str = "half_snake", + is_causal: bool = False, + pad_mode: str = "replicate", + ): + super().__init__() + + assert len(pre_up_sample_rates) == len(pre_up_sample_filters) + assert len(resblock_up_sample_rates) == len(resblock_up_sample_filters) + + if not is_causal: + conv_class = Conv1dNorm + else: + conv_class = CausalConv1dNorm + + if not is_causal: + conv_transpose_class = ConvTranspose1dNorm + else: + conv_transpose_class = CausalConvTranspose1dNorm + + self.pre_conv = conv_class( + in_channels=input_dim, + out_channels=input_filters, + kernel_size=kernel_size, + ) + + in_channels = input_filters + self.pre_up_sample_rates = pre_up_sample_rates + self.pre_resblocks = nn.ModuleList([]) + self.pre_up_sample_layers = nn.ModuleList([]) + for up_sample_rate, filters in zip(self.pre_up_sample_rates, pre_up_sample_filters): + res_block = ResidualBlockV2( + channels=in_channels, + filters=(2 * in_channels), + kernel_size=kernel_size, + activation=activation, + is_causal=is_causal, + pad_mode=pad_mode, + ) + self.pre_resblocks.append(res_block) + conv = conv_transpose_class( + in_channels=in_channels, + out_channels=filters, + kernel_size=(2 * up_sample_rate), + stride=up_sample_rate, + activation=activation, + ) + self.pre_up_sample_layers.append(conv) + + in_channels = filters + + self.conv_layers = nn.ModuleList( + [ + ResidualBlockV2( + channels=in_channels, + filters=hidden_filters, + kernel_size=kernel_size, + activation=activation, + is_causal=is_causal, + pad_mode=pad_mode, + ) + for _ in range(n_hidden_layers) + ] + ) + + self.resblock_up_sample_rates = resblock_up_sample_rates + self.resblock_up_sample_layers = nn.ModuleList([]) + self.resblocks = nn.ModuleList([]) + for up_sample_rate, filters in zip(self.resblock_up_sample_rates, resblock_up_sample_filters): + conv = conv_transpose_class( + in_channels=in_channels, + out_channels=filters, + kernel_size=(2 * up_sample_rate), + stride=up_sample_rate, + 
activation=activation, + ) + self.resblock_up_sample_layers.append(conv) + res_block = ResidualBlockV2( + channels=filters, + filters=(2 * filters), + kernel_size=resblock_up_sample_kernel_size, + activation=activation, + is_causal=is_causal, + pad_mode=pad_mode, + ) + self.resblocks.append(res_block) + in_channels = filters + + self.post_conv = conv_class( + in_channels=in_channels, out_channels=1, kernel_size=resblock_up_sample_kernel_size, pad_mode=pad_mode + ) + + self.out_activation = ClampActivation(clamp_training=False) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T_encoded'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, inputs, input_len): + + out = self.pre_conv(inputs=inputs, input_len=input_len) + + audio_len = input_len + for pre_up_sample_rate, pre_up_sample_layer, pre_resblock in zip( + self.pre_up_sample_rates, self.pre_up_sample_layers, self.pre_resblocks + ): + out = pre_resblock(inputs=out, input_len=audio_len) + audio_len = pre_up_sample_rate * audio_len + out = pre_up_sample_layer(inputs=out, input_len=audio_len) + + for conv in self.conv_layers: + out = conv(inputs=out, input_len=audio_len) + + for resblock_up_sample_rate, resblock_up_sample_layer, resblock in zip( + self.resblock_up_sample_rates, self.resblock_up_sample_layers, self.resblocks + ): + audio_len = resblock_up_sample_rate * audio_len + out = resblock_up_sample_layer(inputs=out, input_len=audio_len) + out = resblock(inputs=out, input_len=audio_len) + + out = self.post_conv(inputs=out, input_len=audio_len) + out = rearrange(out, 'B 1 T -> B T') + audio = self.out_activation(out) + audio = mask_sequence_tensor(audio, audio_len) + + return audio, audio_len From 0ad7bcbf7eeca11d5fbc4fb695e3082acba1e50c Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Fri, 3 Oct 2025 14:42:27 -0700 Subject: [PATCH 094/113] Docstrings and comments on EOS handling. (#14847) Docstrings and comments This is a follow-up to: https://github.com/NVIDIA-NeMo/NeMo/pull/14761 Also, changed the default value of forbid_audio_eos to False in the get_forbidden_tokens function since permitting EOS is the more common use case. The default value is not was not being used anywhere so this doesn't change behavior. 
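For reference, a minimal sketch of what the new default means in practice (the layout base_codebook_size + token.value comes from SpecialAudioToken.get_index; the codebook size of 2048 is only a hypothetical value):

from nemo.collections.tts.modules.magpietts_modules import SpecialAudioToken

base = 2048  # hypothetical codec codebook size
eos_id = SpecialAudioToken.get_index(SpecialAudioToken.AUDIO_EOS, base_codebook_size=base)

forbidden_default = SpecialAudioToken.get_forbidden_tokens(base)                        # AUDIO_EOS allowed
forbidden_strict = SpecialAudioToken.get_forbidden_tokens(base, forbid_audio_eos=True)  # AUDIO_EOS forbidden
assert eos_id not in forbidden_default
assert eos_id in forbidden_strict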
Signed-off-by: Fejgin, Roy * Add docstrings for sampling methods Signed-off-by: Fejgin, Roy * Fix isort issue Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 204 ++++++++++++++---- .../tts/modules/magpietts_modules.py | 7 +- 2 files changed, 168 insertions(+), 43 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index e2d9f077424a..be4b582bd036 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -16,7 +16,7 @@ import random import time from functools import partial -from typing import List, Union +from typing import Dict, List, Optional, Union import numpy as np import soundfile as sf @@ -757,12 +757,15 @@ def code_to_str(code): output_str += c logging.debug(output_str) - def clear_forbidden_logits(self, logits, forbid_audio_eos=False): + def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor: """ Sets logits of forbidden tokens to `-inf` so they will never be sampled. - Specifically, we forbid sampling of all special tokens except AUDIO_EOS. + Specifically, we forbid sampling of all special tokens except AUDIO_EOS + which is allowed by default. Args: logits: (B, C, num_audio_tokens_per_codebook) + forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens + from being sampled. Default: False. """ logits[ :, @@ -773,22 +776,77 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False): def local_transformer_sample_maskgit( self, - dec_output, - temperature=0.7, - topk=80, - unfinished_items={}, - finished_items={}, - use_cfg=False, - cfg_scale=1.0, - n_steps=3, - noise_scale=0.0, - fixed_schedule=None, - dynamic_cfg_scale=False, - sampling_type=None, - forbid_audio_eos=False, - ): + dec_output: torch.Tensor, + temperature: float = 0.7, + topk: int = 80, + unfinished_items: Dict[int, bool] = {}, + finished_items: Dict[int, bool] = {}, + use_cfg: bool = False, + cfg_scale: float = 1.0, + n_steps: int = 3, + noise_scale: float = 0.0, + fixed_schedule: Optional[List[int]] = None, + dynamic_cfg_scale: bool = False, + sampling_type: Optional[str] = None, + forbid_audio_eos: bool = False, + ) -> torch.Tensor: """ - Sample codes for one timestep from the local transformer using MaskGit. + Sample audio codes for the current timestep using MaskGit-like iterative + prediction with the local transformer. If frame-stacking is enabled, the + codes for all frames in the stack are sampled, treated as one long sequence. + + The MaskGit process starts with all positions masked and iteratively unmasks the + most confident positions over multiple steps. By "masked" we mean that a + dedicated MASK token is used (as opposed to attention masking). The LT in this + case is a non-causal transformer decoder. At each step the model predicts all + positions at once. Of those predictions, a subset of the most confident + previously-masked positions is kept and unmasked in the next step. The number of + positions that are unmasked at each step is determined by the unmasking + schedule. We support a cosine schedule and a fixed schedule provided by the + user. + + Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG). + + Special handling: + * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) 
from being sampled + * forces / forbids EOS for finished / unfinished items respectively + * optionally, globally forbids audio EOS for all items in the batch. + This is useful early in the generation process. + * supports different unmasking methods, see `sampling_type` argument for details. + + Args: + dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size + and E is primary decoder's embedding dimension. + temperature (float, optional): Sampling temperature + topk (int, optional): Number of top-probability tokens to consider in sampling. + unfinished_items (dict, optional): Dictionary containing indices of batch + items that we are confident have not completed generation. For these items, audio EOS + sampling is forbidden. + finished_items (dict, optional): Dictionary containing indices of batch + items that we are confident are completed. For these items, audio EOS sampling + is forced. + use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size + to be doubled with conditional and unconditional outputs from the primary decoder. + cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True. + n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling. + noise_scale (float, optional): Scale factor for noise to add to confidence scores + during sampling (experimental). + fixed_schedule (list, optional): Fixed schedule for number of tokens to unmask at each step. + If None, uses cosine schedule. + dynamic_cfg_scale (bool, optional): Whether to dynamically adjust CFG scale during + sampling (experimental). + sampling_type (str, optional): Type of sampling strategy. Options are: + ["default", "causal", "purity_causal", "purity_default"]. + * Purity refers to "purity sampling" from https://arxiv.org/abs/2304.01515. If "purity" + is not specified, confidence sampling is used as in the original MaskGit paper. + * "default"/"causal": Controls the order of unmasking across frames when frame-stacking is enabled. + If "causal" is specified, frames are unmasked in causal order. "default" + doesn't impose any constraints on the unmasking order. + forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire + batch. 
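To make the unmasking schedule concrete, a small sketch of the standard MaskGit-style cosine schedule (illustrative; the exact schedule implemented by this method may differ):

import math

def cosine_unmask_counts(num_positions: int, n_steps: int) -> list:
    # Fraction of positions still masked after step t follows cos(pi/2 * t / n_steps).
    masked = [math.ceil(num_positions * math.cos(math.pi / 2 * t / n_steps)) for t in range(n_steps + 1)]
    return [masked[t] - masked[t + 1] for t in range(n_steps)]

print(cosine_unmask_counts(num_positions=8, n_steps=3))  # -> [1, 3, 4]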
+ + Returns: + torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor) """ # dec_output: (B, E) device = dec_output.device @@ -893,7 +951,7 @@ def local_transformer_sample_maskgit( cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits logits[:actual_batch_size] = cfg_logits - # Disallow generation of special tokens (except audio EOS which is handled separately) + # Disallow generation of special tokens logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos) # handle unfinished and finished items @@ -955,17 +1013,56 @@ def local_transformer_sample_maskgit( def local_transformer_sample_autoregressive( self, - dec_output, - temperature=0.7, - topk=80, - unfinished_items={}, - finished_items={}, - use_cfg=False, - cfg_scale=1.0, - use_kv_cache=True, - forbid_audio_eos=False, - ): - # dec_output: (B, E) + dec_output: torch.Tensor, + temperature: float = 0.7, + topk: int = 80, + unfinished_items: Dict[int, bool] = {}, + finished_items: Dict[int, bool] = {}, + use_cfg: bool = False, + cfg_scale: float = 1.0, + use_kv_cache: bool = True, + forbid_audio_eos: bool = False, + ) -> torch.Tensor: + """ + Sample audio codes autoregressively across codebooks using the local + transformer. Uses multinomial sampling with temperature, top-k, and + classifier-free guidance (CFG). + + The sequence is initialized with the primary decoder's hidden output as the only + input and is gradually extended a code for one codebook at a time, appending the + sampled code as input sequence for the next step. At the last step the sequence + is `num_codebooks` long. If frame stacking is enabled, codes for all frames in + the stack are sampled as one long sequence and the final sequence length is + `num_codebooks * frame_stacking_factor` codes long. + + Special handling: + * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled + * forces / forbids EOS for finished / unfinished items respectively + * optionally, globally forbids audio EOS (useful early in the generation process) + + Args: + dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size + and E is primary decoder's embedding dimension. + temperature (float, optional): Sampling temperature. + topk (int, optional): Number of top-probability tokens to consider in sampling. + unfinished_items (dict, optional): Dictionary containing indices of batch + items that we are confident have not completed generation. For these items, audio EOS + sampling is forbidden. + finished_items (dict, optional): Dictionary containing indices of batch + items that we are confident are completed. For these items, audio EOS sampling + is forced. + use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size + to be doubled with conditional and unconditional outputs from the primary decoder. + cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True. + use_kv_cache (bool, optional): Whether to use key-value caching in the transformer. + forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire + batch. + + Returns: + torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor) + where B is batch size (or actual_batch_size if use_cfg=True). 
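Both local-transformer samplers blend conditional and unconditional logits the same way; a minimal sketch of the classifier-free guidance step, assuming the batch is doubled into conditional and unconditional halves as described above:

import torch

def apply_cfg(logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    half = logits.size(0) // 2
    conditional, unconditional = logits[:half], logits[half:]
    # cfg_scale > 1.0 pushes the distribution towards the conditional prediction.
    return cfg_scale * conditional + (1.0 - cfg_scale) * unconditional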
+ """ + self.local_transformer.reset_cache(use_cache=use_kv_cache) dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) @@ -991,9 +1088,11 @@ def local_transformer_sample_autoregressive( codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 + # Disallow generation of special tokens codebook_logits = self.clear_forbidden_logits( codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos ).squeeze(1) + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -1028,14 +1127,40 @@ def local_transformer_sample_autoregressive( def sample_codes_from_logits( self, - all_code_logits_t, - temperature=0.7, - topk=80, - unfinished_items={}, - finished_items={}, - forbid_audio_eos=False, - ): - # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep + all_code_logits_t: torch.Tensor, + temperature: float = 0.7, + topk: int = 80, + unfinished_items: Dict[int, bool] = {}, + finished_items: Dict[int, bool] = {}, + forbid_audio_eos: bool = False, + ) -> torch.Tensor: + """ + Sample codes for all codebooks at a given timestep. Uses multinomial sampling + with temperature and top-k. If frame stacking is on (i.e. `frame_stacking_factor + > 1`), this function will sample across the entire frame stack. + + Special handling: + * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled + * forces / forbids EOS for finished / unfinished items respectively + * optionally, globally forbids audio EOS (useful early in the generation process) + + Args: + all_code_logits_t (torch.Tensor): Logits at a given timestep with shape + (B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor) + temperature (float, optional): Sampling temperature + topk (int, optional): Number of top-probability tokens to consider in sampling. + unfinished_items (dict, optional): Dictionary containing indices of batch + items that we are confident have not completed generation. For these items, audio EOS + sampling is forbidden. + finished_items (dict, optional): Dictionary containing indices of batch + items that we are confident are completed. For these items, audio EOS sampling + is forced. + forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire + batch. + + Returns: + torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor). 
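The per-codebook sampling described above boils down to top-k filtering plus temperature-scaled multinomial sampling; a self-contained sketch of that step (not the exact implementation):

import torch

def sample_topk(logits: torch.Tensor, topk: int = 80, temperature: float = 0.7) -> torch.Tensor:
    # logits: (B, num_tokens_per_codebook); returns one sampled token id per batch item.
    kth_value = torch.topk(logits, topk, dim=-1)[0][:, -1].unsqueeze(-1)
    logits = logits.masked_fill(logits < kth_value, float('-inf'))
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)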
+ """ all_preds = [[] for _ in range(self.frame_stacking_factor)] for fs_index in range(self.frame_stacking_factor): for idx in range(self.num_audio_codebooks): @@ -1048,9 +1173,12 @@ def sample_codes_from_logits( for item_idx in finished_items: codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 + + # Disallow generation of special tokens codebook_logits = self.clear_forbidden_logits( codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos ).squeeze(1) + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index e1c6786ed776..8569b691242f 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -103,15 +103,12 @@ def get_index(token: SpecialAudioToken, base_codebook_size: int): return base_codebook_size + token.value @staticmethod - def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = True) -> list[int]: + def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = False) -> list[int]: """ Returns a list of token indices that should not be sampled or returned to user. Args: base_codebook_size (int): The size of the codec codebook (which is the first part of the embedding table). - forbid_audio_eos (bool): Whether to forbid the AUDIO_EOS token to be sampled. - * Set to `False` when internally generating tokens in MagpieTTS sampling - * Set to `True` when checking validity of tokens to be returned to user - or given to the codec for decoding + forbid_audio_eos (bool): Whether AUDIO_EOS should be forbidden. Default: False (i.e. allowed). """ all_special_tokens = list(SpecialAudioToken) if not forbid_audio_eos: From 066d622a680f178864fef443df6320b36090228a Mon Sep 17 00:00:00 2001 From: Roy Fejgin Date: Mon, 20 Oct 2025 16:00:31 -0700 Subject: [PATCH 095/113] Inference metrics improvements (#14923) * Inference metrics improvements - Add UTMOSv2 MOS estimation - Track total generated duration per dataset, which we'll use an indicator of speech rate. Signed-off-by: Fejgin, Roy * Add UTMOSv2 installation to CI. Until we add this to our docker image, temporarily install it each time the (relevant) CI tests are run. Signed-off-by: Fejgin, Roy * Add UTMOSv2 to violin plots Signed-off-by: Fejgin, Roy * Disable UTMOSv2 installation to debug container build issues Signed-off-by: Fejgin, Roy * Update requirements_tts.txt install UTMOSv2 with proper syntax And also pin to version 1.2.1. Signed-off-by: Fejgin, Roy * Remove UTMOSv2 install step ... it is now included in the docker image. Signed-off-by: Fejgin, Roy * Specify UTMOSv2 version Signed-off-by: Fejgin, Roy * Address PR comments * Don't run UTMOSv2 calculation for datasets with more than 200 records since it takes too long * Accept device parameter for UTMOSv2 calculator * Change command line options to use "--disable-*" format Signed-off-by: Fejgin, Roy * Compute batch of files in UTMOSv2 Signed-off-by: Fejgin, Roy * Fix isort and flake8 errors Signed-off-by: Fejgin, Roy * Add some logging to debug UTMOSv2 score not found error Signed-off-by: Fejgin, Roy * Normalize file paths for UTMOSv2 score lookup Signed-off-by: Fejgin, Roy * Remove debug code Signed-off-by: Fejgin, Roy * Bugfix + remove utmosv2 dataset size restriction 1. Bugfix: use correct device for UTMOSv2 model 2. 
Removed restriction of runnin UTMOSv2 only on datasets with fewer than 200 entries. Reason: after optimizing how UTMOSv2 is run it is now much faster; a 2000-utterance dataset only went from 25 minutes to 28 minutes by adding UTMOSv2. Getting the metric is worth the small additional runtime. Signed-off-by: Fejgin, Roy --------- Signed-off-by: Fejgin, Roy --- nemo/collections/tts/modules/utmosv2.py | 81 +++++++++++++++++++ requirements/requirements_tts.txt | 1 + scripts/magpietts/evaluate_generated_audio.py | 41 +++++++++- scripts/magpietts/infer_and_evaluate.py | 18 +++-- 4 files changed, 135 insertions(+), 6 deletions(-) create mode 100644 nemo/collections/tts/modules/utmosv2.py diff --git a/nemo/collections/tts/modules/utmosv2.py b/nemo/collections/tts/modules/utmosv2.py new file mode 100644 index 000000000000..1da212a75da0 --- /dev/null +++ b/nemo/collections/tts/modules/utmosv2.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import utmosv2 +except ImportError: + raise ImportError( + "UTMOSv2 is not installed. Please install it using `pip install git+https://github.com/sarulab-speech/UTMOSv2.git@v1.2.1`." + ) +from typing import Optional +import torch +from threadpoolctl import threadpool_limits + +""" +Uses the UTMOSv2 model to estimate the MOS of a speech audio file. +""" + + +class UTMOSv2Calculator: + """ + Wrapper around UTMOSv2 MOS estimator to make it easy to use. + Args: + device: The device to place the model on. If None, the best available device will be used. + Default is None. + """ + + def __init__(self, device: Optional[str] = None): + if device is None: + device = get_available_device() + self.model = utmosv2.create_model() + self.model.eval() + self.model.to(torch.device(device)) + + def __call__(self, file_path): + """ + Estimate the MOS of the given speech audio file using UTMOSv2. + """ + with torch.inference_mode(): + # UTMOSv2 tends to launch many OpenMP threads which can overload the machine's CPUs + # without actually speeding up prediction. Limit to 4 threads. + with threadpool_limits(limits=4): + mos_score = self.model.predict(input_path=file_path, num_repetitions=1, num_workers=0) + return mos_score + + def process_directory(self, input_dir: str, batch_size: int = 16) -> list[dict[str, str | float]]: + """ + Computes UTMOSv2 scores for all `*.wav` files in the given directory. + Args: + input_dir: The directory containing the audio files. + batch_size: The number of audio files to process in parallel. + Returns: + A list of dictionaries, each containing the file path and the UTMOSv2 score. + """ + with torch.inference_mode(): + # UTMOSV2 tends to launch many of OpenMP threads which overloads the machine's CPUs + # while actually slowing down the prediction. Limit the number of threads here. 
+ with threadpool_limits(limits=1): + results = self.model.predict( + input_dir=input_dir, num_repetitions=1, num_workers=batch_size, batch_size=batch_size + ) + return results + + +def get_available_device(): + """ + Get the best available device (prefer GPU, fallback to CPU). + """ + if torch.cuda.is_available(): + return "cuda:0" # Use first GPU + else: + return "cpu" diff --git a/requirements/requirements_tts.txt b/requirements/requirements_tts.txt index af2a76932779..927f493e652e 100644 --- a/requirements/requirements_tts.txt +++ b/requirements/requirements_tts.txt @@ -13,4 +13,5 @@ pandas pypinyin pypinyin-dict seaborn +utmosv2 @ git+https://github.com/sarulab-speech/UTMOSv2.git@v1.2.1 diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index d846705ff8e5..e780c7108cf6 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -18,6 +18,7 @@ import pprint import string import tempfile +import time from contextlib import contextmanager from functools import partial @@ -32,6 +33,7 @@ from nemo.collections.asr.metrics.wer import word_error_rate_detail from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance +from nemo.collections.tts.modules.utmosv2 import UTMOSv2Calculator def find_generated_files(audio_dir, prefix, extension): @@ -53,6 +55,19 @@ def find_generated_codec_files(audio_dir): return find_generated_files(audio_dir=audio_dir, prefix="predicted_codes", extension=".pt") +def get_wav_file_duration(audio_path: str) -> float: + """ + Get the duration of an WAV file in seconds. + """ + # get extension of the file + extension = os.path.splitext(audio_path)[1] + if extension.lower() != ".wav": + raise ValueError(f"Audio path {audio_path} is not a WAV file") + info = sf.info(audio_path) + seconds = info.frames / info.samplerate + return seconds + + def read_manifest(manifest_path): records = [] with open(manifest_path, 'r') as f: @@ -153,6 +168,18 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type): return embeddings.squeeze() +def compute_utmosv2_scores(audio_dir, device): + print(f"\nComputing UTMOSv2 scores for files in {audio_dir}...") + start_time = time.time() + utmosv2_calculator = UTMOSv2Calculator(device=device) + utmosv2_scores = utmosv2_calculator.process_directory(audio_dir) + # convert to to a dictionary indexed by file path + utmosv2_scores_dict = {os.path.normpath(item['file_path']): item['predicted_mos'] for item in utmosv2_scores} + end_time = time.time() + print(f"UTMOSv2 scores computed for {len(utmosv2_scores)} files in {end_time - start_time:.2f} seconds\n") + return utmosv2_scores_dict + + def evaluate( manifest_path, audio_dir, @@ -161,6 +188,7 @@ def evaluate( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", codecmodel_path=None, + with_utmosv2=True, ): audio_file_lists = find_generated_audio_files(generated_audio_dir) records = read_manifest(manifest_path) @@ -216,10 +244,13 @@ def evaluate( print("No codec model provided, skipping FCD metric") fcd_metric = None + if with_utmosv2: + utmosv2_scores = compute_utmosv2_scores(generated_audio_dir, device) filewise_metrics = [] pred_texts = [] gt_texts = [] gt_audio_texts = [] + total_generated_audio_seconds = 0.0 for ridx, record in enumerate(records): gt_audio_filepath = record['audio_filepath'] context_audio_filepath = record.get('context_audio_filepath', None) @@ -235,6 +266,11 
@@ def evaluate( if fcd_metric is not None: pred_codes_filepath = codes_file_lists[ridx] + if with_utmosv2: + utmosv2_score = utmosv2_scores[os.path.normpath(pred_audio_filepath)] + else: + utmosv2_score = 0.0 + try: if language == "en": with torch.inference_mode(): @@ -334,6 +370,7 @@ def evaluate( gt_context_ssim_alternate = torch.nn.functional.cosine_similarity( gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0 ).item() + total_generated_audio_seconds += get_wav_file_duration(pred_audio_filepath) filewise_metrics.append( { @@ -353,6 +390,7 @@ def evaluate( 'gt_audio_filepath': gt_audio_filepath, 'pred_audio_filepath': pred_audio_filepath, 'context_audio_filepath': context_audio_filepath, + 'utmosv2': utmosv2_score, } ) @@ -408,7 +446,8 @@ def evaluate( hypotheses=gt_audio_texts, references=gt_texts, use_cer=False )[0] avg_metrics["frechet_codec_distance"] = fcd - + avg_metrics["utmosv2_avg"] = sum([m['utmosv2'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics["total_gen_audio_seconds"] = total_generated_audio_seconds pprint.pprint(avg_metrics) return avg_metrics, filewise_metrics diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index a22c178bcec4..68dfd2123666 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -295,6 +295,7 @@ def run_inference( violin_plot_metrics=['cer', 'pred_context_ssim'], eos_detection_method=None, ignore_finished_sentence_tracking=False, + with_utmosv2=True, ): # Load model if hparams_file is not None and checkpoint_file is not None: @@ -369,6 +370,8 @@ def run_inference( ssim_per_dataset = [] cer_per_dataset = [] all_datasets_filewise_metrics = {} # Store filewise metrics for all datasets for combined violin plot + if (not with_utmosv2) and ('utmosv2' in violin_plot_metrics): + violin_plot_metrics.remove('utmosv2') for dataset in datasets: print(f"Evaluating dataset {dataset}") metrics_n_repeated = [] @@ -388,7 +391,7 @@ def run_inference( if not os.path.exists(all_experiment_csv): with open(all_experiment_csv, "w") as f: - header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative" + header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,utmosv2,total_gen_audio_seconds" if compute_fcd: header += ",frechet_codec_distance" header += "\n" @@ -535,6 +538,7 @@ def run_inference( sv_model_type=sv_model, asr_model_name=asr_model_name, codecmodel_path=codecmodel_path if compute_fcd else None, + with_utmosv2=with_utmosv2, ) metrics_n_repeated.append(metrics) dataset_filewise_metrics_all_repeats.extend( @@ -552,7 +556,7 @@ def run_inference( json.dump(mean_rtf_metrics, f, indent=4) with open(all_experiment_csv, "a") as f: - data = 
f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']}" + data = f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']},{metrics['utmosv2_avg']},{metrics['total_gen_audio_seconds']}" if compute_fcd: data += f",{metrics['frechet_codec_distance']}" data += "\n" @@ -582,6 +586,8 @@ def run_inference( 'ssim_gt_context_avg_alternate', 'cer_gt_audio_cumulative', 'wer_gt_audio_cumulative', + 'utmosv2_avg', + 'total_gen_audio_seconds', ] if compute_fcd: metric_keys.append('frechet_codec_distance') @@ -591,13 +597,13 @@ def run_inference( all_experiment_csv_with_ci = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") if not os.path.exists(all_experiment_csv_with_ci): with open(all_experiment_csv_with_ci, "w") as f: - header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative" + header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,utmosv2_avg,total_gen_audio_seconds" if compute_fcd: header += ",frechet_codec_distance" header += "\n" f.write(header) with open(all_experiment_csv_with_ci, "a") as f: - data = f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']}" + data = f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']},{metrics_mean_ci['utmosv2_avg']},{metrics_mean_ci['total_gen_audio_seconds']}" if compute_fcd: data += f",{metrics_mean_ci['frechet_codec_distance']}" data += "\n" @@ -682,6 +688,7 
@@ def main(): parser.add_argument('--clean_up_disk', action='store_true') parser.add_argument('--cer_target', type=float, default=None) parser.add_argument('--ssim_target', type=float, default=None) + parser.add_argument('--disable_utmosv2', action='store_true', help="Disable UTMOSv2 computation") parser.add_argument( '--log_exp_name', action='store_true', @@ -692,7 +699,7 @@ def main(): '--violin_plot_metrics', type=str, nargs='*', - default=['cer', 'pred_context_ssim'], + default=['cer', 'pred_context_ssim', 'utmosv2'], help="Which metrics to add the violin plot.", ) args = parser.parse_args() @@ -744,6 +751,7 @@ def main(): violin_plot_metrics=args.violin_plot_metrics, eos_detection_method=args.eos_detection_method, ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking, + with_utmosv2=not args.disable_utmosv2, ) # Mode 1: Run inference from provided hparams and checkpoint files From 22be3f4b0e2d45ea6f78f1050673fbb0660d2a0d Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 21 Oct 2025 12:14:27 -0400 Subject: [PATCH 096/113] New ML yaml + changes to allow for Spectral Codec training with text context (#14894) * Add new config Signed-off-by: Jason * update wandb configs Signed-off-by: Jason * update config Signed-off-by: Jason * add separate tokenizer for text condition Signed-off-by: Jason * update codec loading Signed-off-by: Jason * add it tokenizer Signed-off-by: Jason * fix attempt 1 Signed-off-by: Jason * add an additional +1 for dataset Signed-off-by: Jason * Apply isort and black reformatting Signed-off-by: blisc --------- Signed-off-by: Jason Signed-off-by: blisc Co-authored-by: blisc --- .../magpietts_multilingual_v2_lhotse.yaml | 261 ++++++++++++++++++ .../tts/data/text_to_speech_dataset.py | 4 +- nemo/collections/tts/models/magpietts.py | 31 ++- 3 files changed, 286 insertions(+), 10 deletions(-) create mode 100644 examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml new file mode 100644 index 000000000000..3e4bde22e8dd --- /dev/null +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml @@ -0,0 +1,261 @@ +name: Magpie-TTS-ML + +quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +train_ds_meta: ??? +val_ds_meta: ??? + +model: + use_lhotse: true + model_type: "decoder_ce" # single_encoder_sv_tts, decoder_context_tts or decoder_pretrain_synthesizer + use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. + text_conditioning_tokenizer_name: text_ce_tokenizer + context_duration_min: 5.0 + context_duration_max: 5.0 + load_cached_codes_if_available: true + prior_scaling_factor: 0.5 + prior_end_step: 12000 + prior_scaledown_start_step: 8000 + indefinite_prior_prob: 0. # If > 0, then prior will be applied after prior_end_step with this probability. + alignment_loss_scale: 0.002 + embedding_dim: 768 + codecmodel_path: ??? 
+ cfg_unconditional_prob: 0.1 + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + + # Local transformer parameters for autoregressive codebook prediction within a frame + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. + + text_tokenizers: # Add more languages for multi-lingual TTS + english_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + spanish_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + locale: es-ES + punct: true + apostrophe: true + pad_with_space: true + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + locale: es-ES + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + german_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + locale: de-DE + punct: true + apostrophe: true + pad_with_space: true + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + locale: 'de-DE' + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + grapheme_case: mixed + grapheme_prefix: '#' + mandarin_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer + punct: true + apostrophe: true + pad_with_space: true + g2p: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: "" + ascii_letter_case: "upper" + french_chartokenizer: + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + hindi_phoneme: + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + italian_phoneme: + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + vietnamese_phoneme: + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + text_ce_tokenizer: + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + + train_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration : ??? # in seconds. Adjust based on your GPU memory. 
+ quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 20_000 + shuffle_buffer_size: 20_000 + num_cuts_for_bins_estimate: 20_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 6 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + + validation_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. + quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + drop_last: false + shuffle: false + num_workers: 2 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + encoder: + n_layers: 6 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + context_encoder: # Only used for multi_encoder_context_tts and decoder_ce + n_layers: 1 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: false + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + decoder: + n_layers: 12 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 1 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: true + xa_d_head: 128 + xa_d_memory: 768 + xa_n_heads: 1 + is_causal: true + apply_norm_to_cond: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + make_prior_window_strict: true + + optim: + _target_: torch.optim.AdamW + lr: 2e-4 + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: 32 + max_steps: ??? + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 1 + limit_train_batches: 1_000 + val_check_interval: 1_000 + num_sanity_val_steps: 0 + benchmark: false + use_distributed_sampler: false # required because Lhotse has its own handling + gradient_clip_val: 2.5 + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + entity: null + project: null + group: null + name: ${name} + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. 
+ create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + resume_if_exists: true + resume_ignore_no_checkpoint: true diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index a64160011606..e7e58cd7bd9b 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -550,7 +550,9 @@ def __getitem__(self, index): example['context_audio_codes_len'] = context_audio_codes_len else: # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes - context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32) + # @blisc: Added a +1. If we send in exactly 882 samples, then a conv layer complains about padding. + # Adding 883 works. This occurs when we use text context during inference. + context_audio = torch.zeros(self.codec_model_samples_per_frame + 1, dtype=torch.float32) context_audio_len = context_audio.shape[0] example['context_audio'] = context_audio example['context_audio_len'] = context_audio_len diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index be4b582bd036..44d89cdda843 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -94,9 +94,13 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_devices - # load codec - codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) - + # load codec, disable loading of loss modules not needed during inference + codec_model_cfg = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), return_config=True) + if "use_scl_loss" in codec_model_cfg: + codec_model_cfg.use_scl_loss = False + codec_model = AudioCodecModel.restore_from( + cfg.get('codecmodel_path'), strict=False, override_config_path=codec_model_cfg + ) self.sample_rate = codec_model.sample_rate self.codec_model_samples_per_frame = codec_model.samples_per_frame # del codec discriminator to free memory @@ -108,16 +112,25 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): vector_quantizer = cfg.get('vector_quantizer') if vector_quantizer is not None: vector_quantizer = instantiate(vector_quantizer) - self.num_audio_codebooks = vector_quantizer.num_codebooks - self.codebook_size = vector_quantizer.codebook_size + num_audio_codebooks = vector_quantizer.num_codebooks + codebook_size = vector_quantizer.codebook_size codec_converter = VectorQuantizerIndexConverter( vector_quantizer_original=codec_model.vector_quantizer, vector_quantizer_new=vector_quantizer, ) + data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks else: - self.num_audio_codebooks = codec_model.num_codebooks - self.codebook_size = codec_model.codebook_size + num_audio_codebooks = codec_model.num_codebooks + data_num_audio_codebooks = num_audio_codebooks + codebook_size = codec_model.codebook_size codec_converter = None + # The dataloader needs to know the number of codebooks that the context codes were stored in + # In the case where there are no context codes saved, and there is no context audio (in the text context path), + # We create a dummy context code tensor that is only [context_BOS, context_EOS] that is repeated for + # data_num_audio_codebooks + self.data_num_audio_codebooks = 
data_num_audio_codebooks + self.num_audio_codebooks = num_audio_codebooks + self.codebook_size = codebook_size # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code. @@ -2648,7 +2661,7 @@ def get_dataset(self, dataset_cfg, dataset_type): audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, context_audio_eos_id=self.context_audio_eos_id, - num_audio_codebooks=self.num_audio_codebooks, + num_audio_codebooks=self.data_num_audio_codebooks, codec_model_samples_per_frame=self.codec_model_samples_per_frame, prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, @@ -2678,7 +2691,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, context_audio_eos_id=self.context_audio_eos_id, - num_audio_codebooks=self.num_audio_codebooks, + num_audio_codebooks=self.data_num_audio_codebooks, prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) From 3f8e70de7d037a5e3e66500072082d6d2095b271 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 21 Oct 2025 20:18:51 -0400 Subject: [PATCH 097/113] attempt to remove coverage calls (#14963) Signed-off-by: Jason --- .github/actions/test-template/action.yml | 40 ++++++++++++------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/.github/actions/test-template/action.yml b/.github/actions/test-template/action.yml index e2ac1852dec6..2ad268f9d9d2 100644 --- a/.github/actions/test-template/action.yml +++ b/.github/actions/test-template/action.yml @@ -179,13 +179,13 @@ runs: potential_infra_failure=$(cat $DIR/err.log | grep -Eqiw "device" && echo true || echo false) echo "potential_infra_failure=$potential_infra_failure" >> "$GITHUB_OUTPUT" - docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage combine - docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage xml - docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/.coverage $DIR/.coverage - docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/coverage.xml $DIR/coverage.xml + # docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage combine + # docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage xml + # docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/.coverage $DIR/.coverage + # docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/coverage.xml $DIR/coverage.xml - coverage_report=coverage-${{ steps.create.outputs.coverage-prefix }}-${{ github.run_id }}-$(uuidgen) - echo "coverage_report=$coverage_report" >> "$GITHUB_OUTPUT" + # coverage_report=coverage-${{ steps.create.outputs.coverage-prefix }}-${{ github.run_id }}-$(uuidgen) + # echo "coverage_report=$coverage_report" >> "$GITHUB_OUTPUT" IS_SUCCESS=$(tail -n 1 $DIR/err.log | grep -q "Finished successfully." 
&& echo "true" || echo "false") @@ -201,20 +201,20 @@ runs: exit $EXIT_CODE - - name: Test coverage - shell: bash -x -e -u -o pipefail {0} - run: | - docker exec -t nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage report -i - - - name: Upload artifacts - uses: actions/upload-artifact@v4 - if: ${{ steps.check.outputs.coverage_report != 'none' }} - with: - name: ${{ steps.check.outputs.coverage_report }} - path: | - ${{ github.run_id }}/coverage.xml - ${{ github.run_id }}/.coverage - include-hidden-files: true + # - name: Test coverage + # shell: bash -x -e -u -o pipefail {0} + # run: | + # docker exec -t nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage report -i + + # - name: Upload artifacts + # uses: actions/upload-artifact@v4 + # if: ${{ steps.check.outputs.coverage_report != 'none' }} + # with: + # name: ${{ steps.check.outputs.coverage_report }} + # path: | + # ${{ github.run_id }}/coverage.xml + # ${{ github.run_id }}/.coverage + # include-hidden-files: true - name: Container shutdown if: always() From df85e6da68551697e50d38fd20878e1dd1556bf0 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Wed, 22 Oct 2025 11:15:17 -0700 Subject: [PATCH 098/113] [lhotse][aistore] added support input_cfg.yaml directly from aistore bucket (#14900) * [lhotse][aistore] added support input_cfg.yaml directly from aistore bucket Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * fix: convert pythoon dict obj into DictConf Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * move OmegaConf.create() outside of for loop. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- nemo/collections/common/data/lhotse/cutset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 4261a5488bd2..ef6bbb00b8cc 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -24,6 +24,7 @@ from lhotse import CutSet, Features, Recording from lhotse.array import Array, TemporalArray from lhotse.cut import Cut, MixedCut, PaddingCut +from lhotse.serialization import load_yaml from omegaconf import DictConfig, ListConfig, OmegaConf from nemo.collections.common.data.lhotse.nemo_adapters import ( @@ -343,12 +344,11 @@ def parse_and_combine_datasets( tarred_status = [] if isinstance(config_list, (str, Path)): - # Resolve /path/to/input_cfg.yaml into config contents if needed. - config_list = OmegaConf.load(config_list) + # Resolve local filepath /path/to/input_cfg.yaml or remote url s3://bucket/path/to/input_cfg.yaml into config contents if needed. + config_list = OmegaConf.create(load_yaml(config_list)) assert len(config_list) > 0, "Empty group in dataset config list." for item in config_list: - # Check if we have any attributes that are propagated downwards to each item in the group. # If a key already exists in the item, it takes precedence (we will not overwrite); # otherwise we will assign it. 
From 239adf4820fc93d0f4303ae003cead2f55fc5179 Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 29 Oct 2025 13:10:00 -0400 Subject: [PATCH 099/113] Update checkpoint saving logic: Use step instead of epoch (#14974) * [lhotse][aistore] added support input_cfg.yaml directly from aistore bucket Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * fix: convert pythoon dict obj into DictConf Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * move OmegaConf.create() outside of for loop. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * update checkpointing logic Signed-off-by: Jason * remove debug Signed-off-by: Jason * update filename in yamls and not in exp_manager Signed-off-by: Jason * undo removal of encoder in config Signed-off-by: Jason --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Jason Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- examples/tts/conf/magpietts/magpietts_dc_en.yaml | 1 + examples/tts/conf/magpietts/magpietts_en.yaml | 1 + examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml | 2 +- examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml | 2 +- examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml | 1 + .../tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml | 1 + 6 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml index 76d6fdde052a..8765af931139 100644 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_dc_en.yaml @@ -169,5 +169,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true + filename: "${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}" resume_if_exists: true resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml index 5d45bbb23764..a68f53cba5d1 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_en.yaml @@ -184,5 +184,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true + filename: "${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}" resume_if_exists: true resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml index b3060c5b8384..a21e32063dc8 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml @@ -189,6 +189,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true - filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}-{epoch}' + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' resume_if_exists: true resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml index b1dd69b17723..714a72f4ffa5 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml @@ -189,6 +189,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true - filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}-{epoch}' + filename: 
'${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' resume_if_exists: true resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml index 0e4cf263d469..14529a3172c5 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml @@ -229,5 +229,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' resume_if_exists: true resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml index 3e4bde22e8dd..b2be1234946e 100644 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml +++ b/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml @@ -257,5 +257,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' resume_if_exists: true resume_ignore_no_checkpoint: true From fe77a4cd8129e343f5730347ee9505120ae84f88 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 30 Oct 2025 12:44:40 -0700 Subject: [PATCH 100/113] undo changes to ci Signed-off-by: Jason --- .github/actions/test-template/action.yml | 22 ++++++++--------- .github/scripts/components_to_run.py | 2 +- .github/workflows/cicd-main-nemo2.yml | 4 +++- .github/workflows/cicd-main-speech.yml | 8 +++++-- .github/workflows/cicd-main-unit-tests.yml | 28 ++++++++++++++++------ .github/workflows/cicd-main.yml | 5 ++-- .github/workflows/install-test.yml | 4 ++-- 7 files changed, 47 insertions(+), 26 deletions(-) diff --git a/.github/actions/test-template/action.yml b/.github/actions/test-template/action.yml index 5cd03b72d843..8416847c9d12 100644 --- a/.github/actions/test-template/action.yml +++ b/.github/actions/test-template/action.yml @@ -52,11 +52,11 @@ inputs: runs: using: "composite" steps: - # - name: Noop - # shell: bash - # run: | - # chmod -R u+rwX ${{ github.run_id }} - # echo "noop" + - name: Noop + shell: bash + run: | + chmod -R u+rwX ${{ github.run_id }} + echo "noop" - name: Docker system cleanup shell: bash @@ -181,13 +181,13 @@ runs: potential_infra_failure=$(cat $DIR/err.log | grep -Eqiw "device" && echo true || echo false) echo "potential_infra_failure=$potential_infra_failure" >> "$GITHUB_OUTPUT" - # docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage combine - # docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage xml - # docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/.coverage $DIR/.coverage - # docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/coverage.xml $DIR/coverage.xml + docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage combine + docker exec nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage xml + docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/.coverage $DIR/.coverage + docker cp nemo_container_${{ github.run_id }}_${{ inputs.runner }}:/workspace/coverage.xml $DIR/coverage.xml - # coverage_report=coverage-${{ steps.create.outputs.coverage-prefix }}-${{ github.run_id }}-$(uuidgen) - # echo "coverage_report=$coverage_report" >> "$GITHUB_OUTPUT" + coverage_report=coverage-${{ 
steps.create.outputs.coverage-prefix }}-${{ github.run_id }}-$(uuidgen) + echo "coverage_report=$coverage_report" >> "$GITHUB_OUTPUT" IS_SUCCESS=$(tail -n 1 $DIR/err.log | grep -q "Finished successfully." && echo "true" || echo "false") diff --git a/.github/scripts/components_to_run.py b/.github/scripts/components_to_run.py index 9c9973fdae5e..90ba8c5c10bd 100644 --- a/.github/scripts/components_to_run.py +++ b/.github/scripts/components_to_run.py @@ -69,7 +69,7 @@ def main(source_sha: str, target_sha: str): # Build dependency graph dependencies = nemo_dependencies.build_dependency_graph(nemo_root) - test_modules: List[str] = ["speech"] # Always run speech in magpie branch + test_modules: List[str] = [] for changed_file in changed_files: if changed_file in dependencies: test_modules.extend(dependencies[changed_file]) diff --git a/.github/workflows/cicd-main-nemo2.yml b/.github/workflows/cicd-main-nemo2.yml index a633e8fc3675..c44930a9f207 100644 --- a/.github/workflows/cicd-main-nemo2.yml +++ b/.github/workflows/cicd-main-nemo2.yml @@ -294,8 +294,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index b6bd8bceb2ce..e7d46a5e9875 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -65,8 +65,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -200,8 +202,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} diff --git a/.github/workflows/cicd-main-unit-tests.yml b/.github/workflows/cicd-main-unit-tests.yml index ebe596fa90c5..2db945f54ba5 100644 --- a/.github/workflows/cicd-main-unit-tests.yml +++ b/.github/workflows/cicd-main-unit-tests.yml @@ -35,8 +35,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -59,8 +61,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -84,8 +88,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -108,8 +114,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -138,8 +146,10 @@ jobs: steps: - name: Checkout uses: 
actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -162,8 +172,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} @@ -187,8 +199,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: ${{ matrix.script }} diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 9213155184a0..3cbd73fe11be 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -20,7 +20,6 @@ on: - main - r** - weekly-bump* - - magpietts_2508 types: [labeled] workflow_dispatch: inputs: @@ -225,9 +224,11 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} - name: main - uses: ./.github/actions/test-template + uses: NVIDIA/NeMo/.github/actions/test-template@main with: runner: ${{ runner.name }} script: L0_Setup_Test_Data_And_Models diff --git a/.github/workflows/install-test.yml b/.github/workflows/install-test.yml index 1b3e982f8f45..f811e70b18fa 100644 --- a/.github/workflows/install-test.yml +++ b/.github/workflows/install-test.yml @@ -2,8 +2,8 @@ name: CI-Install-Check on: pull_request: - branches: - - main + paths: + - "**" concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} From 40f58479b0d2b83b73f83eabcb0520a0a81545d8 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 30 Oct 2025 12:51:39 -0700 Subject: [PATCH 101/113] undo changes to core Signed-off-by: Jason --- nemo/core/classes/common.py | 1 + nemo/core/classes/modelPT.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 0b63bec73f7c..aeb13a8a4828 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -15,6 +15,7 @@ """Interfaces common to all Neural Modules and Models.""" from __future__ import annotations + import copy import hashlib import inspect diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 8ee97d8693ca..12b1d0fc60fb 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -1459,7 +1459,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st if isinstance(cfg.init_from_ptl_ckpt, str): # Restore checkpoint ckpt_path = cfg.pop('init_from_ptl_ckpt') - ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=False) + ckpt = torch.load(ckpt_path, map_location=map_location) # Restore checkpoint into current model self.load_state_dict(ckpt['state_dict'], strict=False) From b84d71ed6f07a487f1a7a66c60ff46aaef08ad79 Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 31 Oct 2025 11:42:59 -0700 Subject: [PATCH 102/113] address copilot comments Signed-off-by: Jason --- .github/workflows/cicd-main-speech.yml | 1 + .../tts/models/magpietts_preference_optimization.py | 1 - nemo/collections/tts/modules/audio_codec_modules.py | 5 ----- scripts/magpietts/eval_squimmos.py | 1 - .../extend_nemo_manifest_with_context_audio.py | 3 +-- 
scripts/magpietts/infer_and_evaluate.py | 1 - scripts/tts_dataset_to_lhotse/create_shars.py | 13 +------------ tests/collections/tts/modules/test_fcd_metric.py | 2 +- 8 files changed, 4 insertions(+), 23 deletions(-) diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index e7d46a5e9875..a2bb3c2e7da9 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -186,6 +186,7 @@ jobs: script: SPEECHLM_HF_Training_DuplexS2SSpeechDecoder - runner: self-hosted-azure script: SPEECHLM_HF_Training_SALM + timeout: 20 - runner: self-hosted-azure script: L2_TTS_Fast_dev_runs_Magpietts_DecoderContext - runner: self-hosted-azure diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index 281a1d7dfe24..a8898f856a8d 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -172,7 +172,6 @@ def test_step(self, batch, batch_idx): for idx in range(predicted_audio.size(0)): if not batch_invalid: - audio_path = predicted_audio_paths[idx] item_idx = batch_idx * test_dl_batch_size + idx pred_transcript = pred_transcripts[idx] gt_transcript = process_text_for_cer(batch['raw_texts'][idx]) diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index e84b64cc1ff1..d133a02dd4e7 100755 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -1452,11 +1452,6 @@ def codebook_size(self): """Returns the size of the codebook for each group.""" return self.fsqs[0].codebook_size - # @property - # def codebook_size(self): - # """Returns the size of the implicit codebook.""" - # return self.codebook_size_per_group**self.num_groups - @property def codebook_dim(self): """Input vector dimension.""" diff --git a/scripts/magpietts/eval_squimmos.py b/scripts/magpietts/eval_squimmos.py index 06a57152314e..af3a02cf381d 100644 --- a/scripts/magpietts/eval_squimmos.py +++ b/scripts/magpietts/eval_squimmos.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License.from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE import argparse -import json import os import librosa diff --git a/scripts/magpietts/extend_nemo_manifest_with_context_audio.py b/scripts/magpietts/extend_nemo_manifest_with_context_audio.py index 98c59eb379fd..fb678f02f8d0 100644 --- a/scripts/magpietts/extend_nemo_manifest_with_context_audio.py +++ b/scripts/magpietts/extend_nemo_manifest_with_context_audio.py @@ -18,7 +18,6 @@ import os import random import re -import time from collections import defaultdict from pathlib import Path @@ -98,7 +97,7 @@ "context_audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", "context_audio_offset": 5.6, "context_audio_duration": 6.0, - "context_audio_text": "would you mind..", + "context_audio_text": "would you mind..", "context_audio_normalized_text": "would you mind..", "context_audio_speaker_similarity": 0.85 } diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 68dfd2123666..54b1072b6893 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -774,7 +774,6 @@ def main(): checkpoint_file=checkpoint_file, nemo_file=None, ) - return # Mode 2: Run inference from a .nemo file 
elif args.nemo_files: print(f"Running inference for nemo file: {args.nemo_files}") diff --git a/scripts/tts_dataset_to_lhotse/create_shars.py b/scripts/tts_dataset_to_lhotse/create_shars.py index aa72e7fd16c4..e017e59acbf3 100644 --- a/scripts/tts_dataset_to_lhotse/create_shars.py +++ b/scripts/tts_dataset_to_lhotse/create_shars.py @@ -12,24 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -import csv import json import os -import shutil from pathlib import Path ### from nemo.collections.tts.models import AudioCodecModel -import librosa -import numpy as np -import soundfile as sf -import torch -from lhotse import AudioSource, CutSet, Recording, SupervisionSegment -from lhotse.array import Array, TemporalArray +from lhotse import CutSet, Recording, SupervisionSegment from lhotse.audio import RecordingSet -from lhotse.cut.base import Cut -from lhotse.features.base import Features, FeatureSet from lhotse.shar.writers import AudioTarWriter -from matplotlib import pyplot as plt from tqdm import tqdm @@ -118,7 +108,6 @@ def create_shar_from_manifest(manifest, audio_root_path, out_shar_dir, shard_siz print("...Making Shars") out_shar_dir = Path(out_shar_dir) out_shar_dir.mkdir(parents=True, exist_ok=True) - shard_size = shard_size assert len(user_recordings) % shard_size != 0, "Lhotse breaks if feat_list is a multiple of shard_size" exported = cuts.to_shar(out_shar_dir, fields={"recording": "wav"}, num_jobs=4, shard_size=shard_size) print(f"...share created") diff --git a/tests/collections/tts/modules/test_fcd_metric.py b/tests/collections/tts/modules/test_fcd_metric.py index 25f9a4227934..dfbcbe1f90e0 100644 --- a/tests/collections/tts/modules/test_fcd_metric.py +++ b/tests/collections/tts/modules/test_fcd_metric.py @@ -113,7 +113,7 @@ def test_empty_codes_update(self, metric, device): @pytest.mark.unit def test_codebooks_mismatch_update(self, metric, device, codec): - """Test that the FCD metric doesn't crash when provided with incorrect number ofcodebooks.""" + """Test that the FCD metric doesn't crash when provided with incorrect number of codebooks.""" B = 2 C = codec.num_codebooks - 1 # intentionally missing one codebook T = 10 From f7230c12c90f79af85f66045b81c241f797fa3fe Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 3 Nov 2025 12:36:03 -0800 Subject: [PATCH 103/113] remove notebooks; address flake8 comments; add import guard in magpie RL; remove some experimental flags Signed-off-by: Jason --- .../text_to_speech/tts_tokenizers.py | 382 +++++++++--------- .../tts/data/text_to_speech_dataset.py | 1 - nemo/collections/tts/data/vocoder_dataset.py | 1 - nemo/collections/tts/g2p/models/i18n_ipa.py | 1 - nemo/collections/tts/models/audio_codec.py | 1 - .../magpietts_preference_optimization.py | 14 +- .../tts/modules/encodec_modules.py | 1 - nemo/collections/tts/modules/fcd_metric.py | 12 +- .../parts/preprocessing/feature_processors.py | 1 - .../tts/parts/preprocessing/features.py | 1 - nemo/collections/tts/parts/utils/callbacks.py | 1 - t5tts_inference.ipynb | 339 ---------------- t5tts_inference_multiturndialogues.ipynb | 320 --------------- 13 files changed, 209 insertions(+), 866 deletions(-) delete mode 100644 t5tts_inference.ipynb delete mode 100644 t5tts_inference_multiturndialogues.ipynb diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 8e5b0eae9e8f..567dfc6b863b 100644 --- 
a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -41,19 +41,19 @@ class BaseTokenizer(ABC): + """Abstract class for creating an arbitrary tokenizer to convert string to list of int tokens. + Args: + tokens: List of tokens. + pad: Pad token as string. + blank: Blank token as string. + oov: OOV token as string. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + """ PAD, BLANK, OOV = '', '', '' def __init__(self, tokens, *, pad=PAD, blank=BLANK, oov=OOV, sep='', add_blank_at=None): - """Abstract class for creating an arbitrary tokenizer to convert string to list of int tokens. - Args: - tokens: List of tokens. - pad: Pad token as string. - blank: Blank token as string. - oov: OOV token as string. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - """ super().__init__() tokens = list(tokens) @@ -95,6 +95,17 @@ def decode(self, tokens: List[int]) -> str: class BaseCharsTokenizer(BaseTokenizer): + """Base class for char-based tokenizer. + Args: + chars: string that represents all possible characters. + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + """ # fmt: off # TODO @xueyang: unify definition of the default PUNCT_LIST and import from ipa_lexicon.py PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally @@ -114,17 +125,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=lambda x: x, ): - """Base class for char-based tokenizer. - Args: - chars: string that represents all possible characters. - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - """ tokens = [] self.space, tokens = len(tokens), tokens + [' '] # Space @@ -175,6 +175,17 @@ def encode(self, text): class EnglishCharsTokenizer(BaseCharsTokenizer): + """English char-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. 
+ Basically, it replaces all non-unicode characters with unicode ones and apply lower() function. + """ def __init__( self, punct=True, @@ -184,17 +195,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=english_text_preprocessing, ): - """English char-based tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones and apply lower() function. - """ super().__init__( chars=string.ascii_lowercase, punct=punct, @@ -207,6 +207,17 @@ def __init__( class VietnameseCharsTokenizer(BaseCharsTokenizer): + """Vietnamese grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it + would keep any word lowercase. + """ _LOCALE = "vi-VN" _CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="mixed") @@ -221,17 +232,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=vietnamese_text_preprocessing, ): - """Vietnamese grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it - would keep any word lowercase. - """ super().__init__( chars=chars, punct=punct, @@ -244,6 +244,17 @@ def __init__( class GermanCharsTokenizer(BaseCharsTokenizer): + """German grapheme-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it + would keep any word unchanged. + """ _LOCALE = "de-DE" _PUNCT_LIST = get_ipa_punctuation_list(_LOCALE) @@ -259,17 +270,6 @@ def __init__( non_default_punct_list=_PUNCT_LIST, text_preprocessing_func=any_locale_text_preprocessing, ): - """German grapheme-based tokenizer. 
- Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it - would keep any word unchanged. - """ super().__init__( chars=chars, punct=punct, @@ -282,6 +282,15 @@ def __init__( class SpanishCharsTokenizer(BaseCharsTokenizer): + """Spanish grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + """ PUNCT_LIST = get_ipa_punctuation_list("es-ES") @@ -293,15 +302,6 @@ def __init__( pad_with_space=False, non_default_punct_list=None, ): - """Spanish grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - """ es_alphabet = "abcdefghijklmnopqrstuvwxyzáéíñóúü" super().__init__( @@ -316,6 +316,15 @@ def __init__( class FrenchCharsTokenizer(BaseCharsTokenizer): + """French grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + """ PUNCT_LIST = get_ipa_punctuation_list("fr-FR") @@ -327,15 +336,6 @@ def __init__( pad_with_space=False, non_default_punct_list=None, ): - """French grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - """ fr_alphabet = get_grapheme_character_set(locale="fr-FR", case="lower") super().__init__( @@ -350,20 +350,20 @@ def __init__( class ItalianCharsTokenizer(BaseCharsTokenizer): + """Italian grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. 
+ pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + """ PUNCT_LIST = get_ipa_punctuation_list("it-IT") def __init__( self, punct=True, apostrophe=True, add_blank_at=None, pad_with_space=False, non_default_punct_list=None ): - """Italian grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - """ it_alphabet = "abcdefghijklmnopqrstuvwxyzàèéìòùó" super().__init__( @@ -378,6 +378,17 @@ def __init__( class GermanPhonemesTokenizer(BaseCharsTokenizer): + """Deutsch phoneme-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Currently, it only applies lower() function. + """ # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -395,17 +406,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=any_locale_text_preprocessing, ): - """Deutsch phoneme-based tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Currently, it only applies lower() function. - """ de_ipa = "abdefhijklmnoprstuvwxyzçðøŋœɐɑɒɔəɛɜɡɪɹɾʃʊʌʒː̃" de_suprasegmentals = "12" @@ -449,6 +449,17 @@ def encode(self, text): class ItalianPhonemesTokenizer(BaseCharsTokenizer): + """Italian phoneme-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Currently, it only applies lower() function. + """ # fmt: off PUNCT_LIST = ( ',', '.', '!', '?', '-', @@ -467,17 +478,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=italian_text_preprocessing, ): - """Italian phoneme-based tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. 
- apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Currently, it only applies lower() function. - """ it_ipa = ( "abcdefghijklmnopqrstuvwxyzàèéìòùóæɐɑɔəɚɜɬɹʌʔᵻðŋɛɡɣɪɲɾʃʊʎʒʝβθd͡'t͡'øɒɕɓçɖɘɝɞɟʄɡɠɢʛɦɧħɥʜɨɬɫɮʟɱɯɰɳɵɸœɶʘɺ" @@ -523,6 +523,27 @@ def encode(self, text): class EnglishPhonemesTokenizer(BaseTokenizer): + """English phoneme-based tokenizer. + Args: + g2p: Grapheme to phoneme module. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + stresses: Whether to use phonemes codes with stresses (0-2) or not. + chars: Whether to additionally use chars together with phonemes. It is useful if g2p module can return + chars too. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + oov: OOV token as string. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones. + Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be + handled by g2p). + """ # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -559,27 +580,6 @@ def __init__( pad_with_space=False, text_preprocessing_func=lambda text: english_text_preprocessing(text, lower=False), ): - """English phoneme-based tokenizer. - Args: - g2p: Grapheme to phoneme module. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - stresses: Whether to use phonemes codes with stresses (0-2) or not. - chars: Whether to additionally use chars together with phonemes. It is useful if g2p module can return - chars too. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - oov: OOV token as string. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones. - Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be - handled by g2p). - """ self.phoneme_probability = None if hasattr(g2p, "phoneme_probability"): @@ -683,8 +683,30 @@ def set_phone_prob(self, prob): self.g2p.phoneme_probability = self.phoneme_probability -@experimental class IPATokenizer(BaseTokenizer): + """General-purpose IPA-based tokenizer. 
+ Args: + g2p: Grapheme to phoneme module, should be IpaG2p or some subclass thereof. + locale: Locale used to determine default text processing logic and punctuation. + Supports ["en-US", "de-DE", "es-ES", "fr-FR"]. Defaults to "en-US". + Specify None if implementing custom logic for a new locale. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default, if any. + fixed_vocab: List of valid grapheme/phoneme tokens for the model. + Set only if overriding the default vocab generation process (reading from G2P dict). + If set, any dataset entries that have unincluded graphemes will be filtered out, and any words whose + pronunciations have unincluded phonemes will be treated as OOV. + Please make sure that the grapheme prefixes and cases are consistent with the G2P module's settings. + Defaults to None, which means default vocab generation is used. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + oov: OOV token as string. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + """ def __init__( self, g2p, @@ -701,29 +723,6 @@ def __init__( add_blank_at=None, pad_with_space=False, ): - """General-purpose IPA-based tokenizer. - Args: - g2p: Grapheme to phoneme module, should be IpaG2p or some subclass thereof. - locale: Locale used to determine default text processing logic and punctuation. - Supports ["en-US", "de-DE", "es-ES", "fr-FR"]. Defaults to "en-US". - Specify None if implementing custom logic for a new locale. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default, if any. - fixed_vocab: List of valid grapheme/phoneme tokens for the model. - Set only if overriding the default vocab generation process (reading from G2P dict). - If set, any dataset entries that have unincluded graphemes will be filtered out, and any words whose - pronunciations have unincluded phonemes will be treated as OOV. - Please make sure that the grapheme prefixes and cases are consistent with the G2P module's settings. - Defaults to None, which means default vocab generation is used. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - oov: OOV token as string. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - """ if not hasattr(g2p, "symbols"): logging.error( f"Please make sure the G2P module passed into the IPATokenizer has a `symbols` attribute. " @@ -858,6 +857,25 @@ def set_phone_prob(self, prob): class ChinesePhonemesTokenizer(BaseTokenizer): + """Chinese phoneme-based tokenizer. + Note: This tokenizer for now covers Chinese phonemes/tones and English letters because our dataset contains + both Chinese and English graphemes. + Args: + g2p: Grapheme to phoneme module. + punct: Whether to reserve grapheme for basic punctuation or not. 
+ non_default_punct_list: List of punctuation marks which will be used instead default. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones. + Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be + handled by g2p). + """ # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -865,6 +883,7 @@ class ChinesePhonemesTokenizer(BaseTokenizer): ')', '[', ']', '{', '}', ) ZH_PUNCT_LIST = list(",。?!;:、‘’“”()【】「」《》") + list(PUNCT_LIST) + # fmt: on def __init__( self, @@ -880,25 +899,6 @@ def __init__( pad_with_space=False, text_preprocessing_func=chinese_text_preprocessing, ): - """Chinese phoneme-based tokenizer. - Note: This tokenizer for now covers Chinese phonemes/tones and English letters because our dataset contains - both Chinese and English graphemes. - Args: - g2p: Grapheme to phoneme module. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones. - Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be - handled by g2p). - """ tokens = [] self.space, tokens = len(tokens), tokens + [space] # Space @@ -977,6 +977,24 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): class JapanesePhonemeTokenizer(BaseTokenizer): + """Japanese phoneme-based tokenizer. + Note: This tokenizer for now covers Japanese phonemes + Args: + g2p: Grapheme to phoneme module. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones. + Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be + handled by g2p). 
+ """ JA_PUNCT_LIST = get_ipa_punctuation_list("ja-JP") @@ -994,24 +1012,6 @@ def __init__( pad_with_space=False, text_preprocessing_func=japanese_text_preprocessing, ): - """Japanese phoneme-based tokenizer. - Note: This tokenizer for now covers Japanese phonemes - Args: - g2p: Grapheme to phoneme module. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones. - Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be - handled by g2p). - """ tokens = [] self.space, tokens = len(tokens), tokens + [space] # Space @@ -1089,13 +1089,13 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): # TODO @xueyang: subclassing from `nemo/collections/common/tokenizers/tokenizer_spec.py::TokenizerSpec`, and/or # adjust to reuse `nemo/collections/common/tokenizers/aggregate_tokenizer.py::AggregateTokenizer` class AggregatedTTSTokenizer: + """A simple aggregated tokenizer. Aggregates multiple tokenizers into one by combining (simply concatenating) + their tokens into one vocabulary. + Args: + tokenizers: List of tokenizers to aggregate. + tokenizer_names: List of names for each tokenizer (usually the language identifier). + """ def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase]], tokenizer_names: List[str]): - """A simple aggregated tokenizer. Aggregates multiple tokenizers into one by combining (simply concatenating) - their tokens into one vocabulary. - Args: - tokenizers: List of tokenizers to aggregate. - tokenizer_names: List of names for each tokenizer (usually the language identifier). - """ assert len(tokenizers) == len(tokenizer_names), "Number of tokenizers and tokenizer names must be the same." tokens = [] tokenizer_offsets = {} diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index e7e58cd7bd9b..29fbe9646848 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -61,7 +61,6 @@ class DatasetSample: tokenizer_names: List[str] = None -@experimental class TextToSpeechDataset(Dataset): """ Class for processing and loading text to speech training examples. diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py index 55fa9d678edd..b6ea56e2ed7f 100644 --- a/nemo/collections/tts/data/vocoder_dataset.py +++ b/nemo/collections/tts/data/vocoder_dataset.py @@ -111,7 +111,6 @@ def preprocess_manifest( return samples, sample_weights -@experimental class VocoderDataset(Dataset): """ Class for processing and loading Vocoder training examples. 
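The AggregatedTTSTokenizer described in the relocated docstring above simply concatenates the vocabularies of its sub-tokenizers and remembers a per-tokenizer offset so token ids never collide. A minimal, self-contained sketch of that offset scheme (a hypothetical ToyAggregatedTokenizer with made-up two-character vocabularies, not the NeMo class itself) looks roughly like this:

class ToyAggregatedTokenizer:
    """Toy illustration of vocabulary concatenation with per-tokenizer offsets."""

    def __init__(self, vocabs: dict):
        self.vocabs = vocabs
        self.offsets = {}
        offset = 0
        for name, vocab in vocabs.items():
            # Each sub-tokenizer's ids are shifted by the size of all vocabularies before it.
            self.offsets[name] = offset
            offset += len(vocab)

    def encode(self, text: str, name: str):
        vocab = self.vocabs[name]
        return [self.offsets[name] + vocab.index(ch) for ch in text]

    def decode(self, tokens, name: str):
        vocab = self.vocabs[name]
        return "".join(vocab[t - self.offsets[name]] for t in tokens)

tok = ToyAggregatedTokenizer({"en": ["a", "b"], "de": ["a", "ü"]})
assert tok.encode("ab", "en") == [0, 1]
assert tok.encode("aü", "de") == [2, 3]  # shifted past the English vocabulary
assert tok.decode([2, 3], "de") == "aü"

Encoding adds the offset of the selected sub-tokenizer and decoding subtracts it again, which is the same round-trip the aggregated tokenizer's encode and decode methods perform.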
diff --git a/nemo/collections/tts/g2p/models/i18n_ipa.py b/nemo/collections/tts/g2p/models/i18n_ipa.py index ed0569eac98d..eb254bfab9f8 100644 --- a/nemo/collections/tts/g2p/models/i18n_ipa.py +++ b/nemo/collections/tts/g2p/models/i18n_ipa.py @@ -31,7 +31,6 @@ from nemo.utils.decorators import experimental -@experimental class IpaG2p(BaseG2p): # fmt: off STRESS_SYMBOLS = ["ˈ", "ˌ"] diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index c700e4ba3b80..1601350457cf 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -52,7 +52,6 @@ HAVE_TORCHAUDIO = False -@experimental class AudioCodecModel(ModelPT): def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Convert to Hydra 1.0 compatible DictConfig diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index a8898f856a8d..84e845cec2f5 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -38,7 +38,13 @@ except ImportError: HAVE_TORCHAUDIO = False -from nemo_text_processing.text_normalization.normalize import Normalizer +try: + from nemo_text_processing.text_normalization.normalize import Normalizer + + PYNINI_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + Normalizer = None + PYNINI_AVAILABLE = False from nemo.collections.tts.models import MagpieTTSModel @@ -69,7 +75,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") self.whisper_model.eval() self._normalize_whisper_transcript = cfg.get('normalize_whisper_transcript', True) - if self._normalize_whisper_transcript: + if self._normalize_whisper_transcript and PYNINI_AVAILABLE: self._normalizer_cache = {} # Pre-create normalizer for the configured language lang = cfg.get('pref_set_language', 'en') @@ -77,6 +83,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): def _get_cached_normalizer(self, lang_key): """Get or create a cached normalizer for the given language.""" + if not PYNINI_AVAILABLE: + return None lang_key = lang_key if lang_key else "en" if lang_key not in self._normalizer_cache: logging.info(f"Creating normalizer for language: {lang_key}") @@ -541,6 +549,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): def _get_cached_normalizer(self, lang_key): """Get or create a cached normalizer for the given language.""" + if not PYNINI_AVAILABLE: + return None lang_key = lang_key if lang_key else "en" if lang_key not in self._normalizer_cache: logging.info(f"Creating normalizer for language: {lang_key}") diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index 377f5671161e..34c35bb105d0 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -537,7 +537,6 @@ def _mask_3d(tensor: Tensor, lengths: Tensor): return tensor * mask -@experimental class EuclideanCodebook(NeuralModule): """ Codebook with Euclidean distance. 
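The Normalizer import change in magpietts_preference_optimization.py above wraps the optional, pynini-backed nemo_text_processing dependency in a try/except and records a PYNINI_AVAILABLE flag, so the model class still imports when that package is absent. A minimal sketch of the guard pattern, with a hypothetical helper name and a Normalizer constructor signature assumed from common usage, could look like:

try:
    from nemo_text_processing.text_normalization.normalize import Normalizer

    PYNINI_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    Normalizer = None
    PYNINI_AVAILABLE = False


def maybe_build_normalizer(lang: str = "en"):
    # Callers treat a None return as "skip transcript normalization",
    # mirroring the PYNINI_AVAILABLE checks added in the diff above.
    if not PYNINI_AVAILABLE:
        return None
    # input_case/lang arguments are assumed from typical Normalizer usage.
    return Normalizer(input_case="cased", lang=lang)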
diff --git a/nemo/collections/tts/modules/fcd_metric.py b/nemo/collections/tts/modules/fcd_metric.py index 2148d9669da1..42751897c906 100644 --- a/nemo/collections/tts/modules/fcd_metric.py +++ b/nemo/collections/tts/modules/fcd_metric.py @@ -47,7 +47,7 @@ def codes_to_embedding(self, x: Tensor, x_len: Tensor) -> Tensor: """ Embeds a batch of audio codec codes into the codec's continuous embedding space. """ - # x: (B, C, T + # x: (B, C, T) # x_len: (B,) return self.codec.dequantize(tokens=x, tokens_len=x_len) @@ -76,7 +76,7 @@ class FrechetCodecDistance(Metric): """ Parts of this are based on the following implementation of FID (Frechet Inception Distance) on images: - + https://github.com/pytorch/torcheval/blob/main/torcheval/metrics/image/fid.py # Copyright (c) Meta Platforms, Inc. and affiliates. @@ -88,14 +88,14 @@ class FrechetCodecDistance(Metric): Contents of original LICENSE file: # BSD License - # + # # For torcheval software # # Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, # are permitted provided that the following conditions are met: - # + # # * Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # @@ -170,12 +170,12 @@ def update(self, codes: Tensor, codes_len: Tensor, is_real: bool): assert codes.ndim == 3 if codes.numel() == 0: - logging.warning(f"\nFCD metric received an empty batch of codes - skipping update\n") + logging.warning("FCD metric received an empty batch of codes - skipping update") return if codes.shape[1] != self.model.codec.num_codebooks: logging.warning( - f"\nFCD metric received a batch of codes of shape {codes.shape}, but the model has {self.model.codec.num_codebooks} codebooks - skipping update\n" + f"FCD metric received a batch of codes of shape {codes.shape}, but the model has {self.model.codec.num_codebooks} codebooks - skipping update" ) return diff --git a/nemo/collections/tts/parts/preprocessing/feature_processors.py b/nemo/collections/tts/parts/preprocessing/feature_processors.py index 19ed8139ae65..afb0fffa5928 100644 --- a/nemo/collections/tts/parts/preprocessing/feature_processors.py +++ b/nemo/collections/tts/parts/preprocessing/feature_processors.py @@ -22,7 +22,6 @@ from nemo.utils.decorators import experimental -@experimental class FeatureProcessor(ABC): @abstractmethod def process(self, training_example: dict) -> None: diff --git a/nemo/collections/tts/parts/preprocessing/features.py b/nemo/collections/tts/parts/preprocessing/features.py index 5067e89f52f8..307f54312214 100644 --- a/nemo/collections/tts/parts/preprocessing/features.py +++ b/nemo/collections/tts/parts/preprocessing/features.py @@ -28,7 +28,6 @@ from nemo.utils.decorators import experimental -@experimental class Featurizer(ABC): @abstractmethod def save(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path, overwrite: bool = True) -> None: diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 1856dee0ce0f..366ff864ffea 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -127,7 +127,6 @@ def generate_artifacts( """ -@experimental class LoggingCallback(Callback): """ Callback which can log artifacts (eg. model predictions, graphs) to local disk, Tensorboard, and/or WandB. 
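The FrechetCodecDistance.update changes above drop the unnecessary f-prefix and stray newlines from warnings and skip batches that cannot contribute to the statistics. A standalone sketch of those guard checks, using a hypothetical helper and an explicit num_codebooks argument rather than the metric's own model attribute, might look like:

import logging

import torch


def should_update_fcd(codes: torch.Tensor, num_codebooks: int) -> bool:
    """Return True when a batch of codec codes should contribute to the FCD statistics."""
    assert codes.ndim == 3  # expected layout: (B, C, T)
    if codes.numel() == 0:
        # Empty batches carry no information; warn and skip.
        logging.warning("FCD metric received an empty batch of codes - skipping update")
        return False
    if codes.shape[1] != num_codebooks:
        # Codebook-count mismatch means the codes came from a different codec configuration.
        logging.warning(
            f"FCD metric received a batch of codes of shape {codes.shape}, "
            f"but the model has {num_codebooks} codebooks - skipping update"
        )
        return False
    return True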
diff --git a/t5tts_inference.ipynb b/t5tts_inference.ipynb deleted file mode 100644 index 875809c6fa94..000000000000 --- a/t5tts_inference.ipynb +++ /dev/null @@ -1,339 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "466ccdc5", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "from nemo.collections.tts.models import T5TTS_Model\n", - "from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample\n", - "from omegaconf.omegaconf import OmegaConf, open_dict\n", - "import torch\n", - "import os\n", - "import soundfile as sf\n", - "from IPython.display import display, Audio\n", - "import numpy as np\n", - "import os\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" - ] - }, - { - "cell_type": "markdown", - "id": "1f5798ac", - "metadata": {}, - "source": [ - "### Checkpoint Paths" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "04445f11", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "hparams_file = \"/datap/misc/experiment_checkpoints/localtransformer/koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN3_hparams.yaml\"\n", - "checkpoint_file = \"/datap/misc/experiment_checkpoints/localtransformer/koel_12.5_FPS_causal_13codebooks_codecmodel_context5sec_LTN3_epoch101.ckpt\"\n", - "# codecmodel_path = \"/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo\"\n", - "codecmodel_path = \"/datap/misc/checkpoints/12.5_FPS_causal_13codebooks_codecmodel.nemo\"\n", - "\n", - "\n", - "# Temp out dir for saving audios\n", - "out_dir = \"/datap/misc/t5tts_inference_notebook_samples\"\n", - "if not os.path.exists(out_dir):\n", - " os.makedirs(out_dir)" - ] - }, - { - "cell_type": "markdown", - "id": "86bf2a16", - "metadata": {}, - "source": [ - "### Load Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "87bf66f9", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "model_cfg = OmegaConf.load(hparams_file).cfg\n", - "\n", - "with open_dict(model_cfg):\n", - " model_cfg.codecmodel_path = codecmodel_path\n", - " if hasattr(model_cfg, 'text_tokenizer'):\n", - " # Backward compatibility for models trained with absolute paths in text_tokenizer\n", - " model_cfg.text_tokenizer.g2p.phoneme_dict = \"scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt\"\n", - " model_cfg.text_tokenizer.g2p.heteronyms = \"scripts/tts_dataset_files/heteronyms-052722\"\n", - " model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0\n", - " model_cfg.train_ds = None\n", - " model_cfg.validation_ds = None\n", - "\n", - "\n", - "model = T5TTS_Model(cfg=model_cfg)\n", - "print(\"Loading weights from checkpoint\")\n", - "ckpt = torch.load(checkpoint_file)\n", - "model.load_state_dict(ckpt['state_dict'])\n", - "print(\"Loaded weights.\")\n", - "\n", - "model.use_kv_cache_for_inference = True\n", - "\n", - "model.cuda()\n", - "model.eval()" - ] - }, - { - "cell_type": "markdown", - "id": "361b5711", - "metadata": {}, - "source": [ - "### Initialize Dataset class and helper functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "840a7271", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "test_dataset = T5TTSDataset(\n", - " dataset_meta={},\n", - " sample_rate=model_cfg.sample_rate,\n", - " min_duration=0.5,\n", - " max_duration=20,\n", - " codec_model_downsample_factor=model_cfg.codec_model_downsample_factor,\n", - " 
bos_id=model.bos_id,\n", - " eos_id=model.eos_id,\n", - " context_audio_bos_id=model.context_audio_bos_id,\n", - " context_audio_eos_id=model.context_audio_eos_id,\n", - " audio_bos_id=model.audio_bos_id,\n", - " audio_eos_id=model.audio_eos_id,\n", - " num_audio_codebooks=model_cfg.num_audio_codebooks,\n", - " prior_scaling_factor=None,\n", - " load_cached_codes_if_available=True,\n", - " dataset_type='test',\n", - " tokenizer_config=None,\n", - " load_16khz_audio=model.model_type == 'single_encoder_sv_tts',\n", - " use_text_conditioning_tokenizer=model.use_text_conditioning_encoder,\n", - " pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,\n", - " context_duration_min=model.cfg.get('context_duration_min', 5.0),\n", - " context_duration_max=model.cfg.get('context_duration_max', 5.0),\n", - ")\n", - "test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test')\n", - "\n", - "\n", - "\n", - "def get_audio_duration(file_path):\n", - " with sf.SoundFile(file_path) as audio_file:\n", - " # Calculate the duration\n", - " duration = len(audio_file) / audio_file.samplerate\n", - " return duration\n", - "\n", - "def create_record(text, context_audio_filepath=None, context_text=None):\n", - " dummy_audio_fp = os.path.join(out_dir, \"dummy_audio.wav\")\n", - " dummy_audio = sf.write(dummy_audio_fp, np.zeros(22050 * 3), 22050) # 3 seconds of silence\n", - " record = {\n", - " 'audio_filepath' : dummy_audio_fp,\n", - " 'duration': 3.0,\n", - " 'text': text,\n", - " 'speaker': \"dummy\",\n", - " }\n", - " if context_text is not None:\n", - " assert context_audio_filepath is None\n", - " record['context_text'] = context_text\n", - " else:\n", - " assert context_audio_filepath is not None\n", - " record['context_audio_filepath'] = context_audio_filepath\n", - " record['context_audio_duration'] = get_audio_duration(context_audio_filepath)\n", - " \n", - " return record" - ] - }, - { - "cell_type": "markdown", - "id": "e9aa7a5a", - "metadata": {}, - "source": [ - "### Set transcript and context pairs to test" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b7374d3f", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "import pprint\n", - "# Change sample text and prompt audio/text here\n", - "audio_base_dir = \"/\"\n", - "test_entries = [\n", - " create_record(\n", - " text=\"The recollection of a unique event cannot, so Bergson contends, be wholly constituted by habit, and is in fact something radically different from the memory which is habit.\",\n", - "# text=\"This is a simple sentence to check my text to speech synthesis model.\",\n", - "# text=\" How? 
\",\n", - "# context_text=\"Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_WIZWIKI |\",\n", - "# Lindy_CMU_FEARFUL\n", - "# Lindy_WIZWIKI\n", - "# context_audio_filepath=\"/datap/misc/LibriTTSfromNemo/LibriTTS/test-clean/7729/102255/7729_102255_000012_000001.wav\", # Supply either context_audio_filepath or context_text, not both\n", - " context_audio_filepath=\"/datap/misc/LibriTTSfromNemo/LibriTTS/test-clean/8230/279154/8230_279154_000004_000009.wav\",\n", - " ),\n", - "]\n", - "\n", - "data_samples = []\n", - "for entry in test_entries:\n", - " dataset_sample = DatasetSample(\n", - " dataset_name=\"sample\",\n", - " manifest_entry=entry,\n", - " audio_dir=audio_base_dir,\n", - " feature_dir=audio_base_dir,\n", - " text=entry['text'],\n", - " speaker=None,\n", - " speaker_index=0,\n", - " tokenizer_names=[\"english_phoneme\"], # Change this for multilingual: \"english_phoneme\", \"spanish_phoneme\", \"english_chartokenizer\", \"german_chartokenizer\".. \n", - " )\n", - " data_samples.append(dataset_sample)\n", - " \n", - "test_dataset.data_samples = data_samples\n", - "\n", - "test_data_loader = torch.utils.data.DataLoader(\n", - " test_dataset,\n", - " batch_size=1,\n", - " collate_fn=test_dataset.collate_fn,\n", - " num_workers=0,\n", - " shuffle=False\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "aab9866b", - "metadata": {}, - "source": [ - "### Generate With Prior" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "745b2ef7", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "item_idx = 0\n", - "for bidx, batch in enumerate(test_data_loader):\n", - " print(\"Processing batch {} out of {}\".format(bidx, len(test_data_loader)))\n", - " model.t5_decoder.reset_cache(use_cache=True)\n", - " batch_cuda ={}\n", - " for key in batch:\n", - " if isinstance(batch[key], torch.Tensor):\n", - " batch_cuda[key] = batch[key].cuda()\n", - " else:\n", - " batch_cuda[key] = batch[key]\n", - " import time\n", - " st = time.time()\n", - " \n", - " for _ in range(1):\n", - " for use_local_transformer_for_inference in [False, True]:\n", - " for apply_prior in [False]:\n", - " predicted_audio, predicted_audio_lens, _, _, rtf_metrics, cross_attn_np, all_heads_attn_np = model.infer_batch(\n", - " batch_cuda, \n", - " max_decoder_steps=430, \n", - " temperature=0.6, \n", - " topk=80, \n", - " use_cfg=False,\n", - " cfg_scale=2.5,\n", - " prior_epsilon=0.1,\n", - " lookahead_window_size=5,\n", - " return_cross_attn_probs=True,\n", - " estimate_alignment_from_layers=[5],\n", - " apply_attention_prior=apply_prior,\n", - " apply_prior_to_layers=[0,1,2,3,4,5,6,7,8,9,10,11],\n", - " compute_all_heads_attn_maps=True,\n", - " start_prior_after_n_audio_steps=0,\n", - " use_local_transformer_for_inference=use_local_transformer_for_inference\n", - " )\n", - " print(\"generation time\", time.time() - st)\n", - " pprint.pprint(rtf_metrics)\n", - " for idx in range(predicted_audio.size(0)):\n", - " predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()\n", - " predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]\n", - " audio_path = os.path.join(out_dir, f\"predicted_audio_{item_idx}.wav\")\n", - " sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)\n", - " print(test_entries[bidx]['text'])\n", - " print(\"Prior Used?\", apply_prior)\n", - " print(\"use_local_transformer\", use_local_transformer_for_inference)\n", - " display(Audio(audio_path))\n", - " item_idx 
+= 1\n", - " plt.imshow(cross_attn_np[idx])\n", - " plt.show()\n", - "# for hidx, head_cross_attn in enumerate(all_heads_attn_np[idx]):\n", - "# layer_num = hidx // model.cfg.t5_decoder.xa_n_heads\n", - "# head_num = hidx % model.cfg.t5_decoder.xa_n_heads\n", - "# print(\"item, layer, head\", idx, layer_num, head_num)\n", - "# plt.imshow(all_heads_attn_np[idx][hidx])\n", - "# plt.show()\n", - "\n", - " print(\"------------------------------------\")\n", - " print(\"------------------------------------\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7adb9800", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b20dab52", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/t5tts_inference_multiturndialogues.ipynb b/t5tts_inference_multiturndialogues.ipynb deleted file mode 100644 index 20f1fa7ff645..000000000000 --- a/t5tts_inference_multiturndialogues.ipynb +++ /dev/null @@ -1,320 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "466ccdc5", - "metadata": {}, - "outputs": [], - "source": [ - "from nemo.collections.tts.models import T5TTS_Model\n", - "from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample\n", - "from omegaconf.omegaconf import OmegaConf, open_dict\n", - "import torch\n", - "import os\n", - "import soundfile as sf\n", - "from IPython.display import display, Audio\n", - "import os\n", - "import numpy as np\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "6659ae78", - "metadata": {}, - "source": [ - "## Set checkpoint and other file paths on your machine" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "04445f11", - "metadata": {}, - "outputs": [], - "source": [ - "hparams_file = \"/datap/misc/duplexcheckpoints/blackwell/duplex_blackwell_medium_decoder_noexpresso_onlyphonemeFT_hparams.yaml\"\n", - "checkpoint_file = \"/datap/misc/duplexcheckpoints/blackwell/duplex_blackwell_medium_decoder_withTC_fromroycheckpoint_lowsestvalloss.ckpt\"\n", - "codecmodel_path = \"/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo\"\n", - "out_dir = \"/datap/misc/t5tts_inference_notebook_samples\"\n", - "if not os.path.exists(out_dir):\n", - " os.makedirs(out_dir)\n", - "\n", - "dummy_audio_filepath = os.path.join(out_dir, \"dummy_audio.wav\")\n", - "sf.write(dummy_audio_filepath, np.zeros(22050 * 3), 22050)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "01e24b14", - "metadata": {}, - "source": [ - "## Load Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "87bf66f9", - "metadata": {}, - "outputs": [], - "source": [ - "model_cfg = OmegaConf.load(hparams_file).cfg\n", - "\n", - "with open_dict(model_cfg):\n", - " model_cfg.codecmodel_path = codecmodel_path\n", - " if hasattr(model_cfg, 'text_tokenizer'):\n", - " # Backward compatibility for models trained with absolute paths in text_tokenizer\n", 
- " model_cfg.text_tokenizer.g2p.phoneme_dict = \"scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt\"\n", - " model_cfg.text_tokenizer.g2p.heteronyms = \"scripts/tts_dataset_files/heteronyms-052722\"\n", - " model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0\n", - " model_cfg.train_ds = None\n", - " model_cfg.validation_ds = None\n", - "\n", - "\n", - "model = T5TTS_Model(cfg=model_cfg)\n", - "# Load weights from checkpoint file\n", - "print(\"Loading weights from checkpoint\")\n", - "ckpt = torch.load(checkpoint_file)\n", - "model.load_state_dict(ckpt['state_dict'])\n", - "print(\"Loaded weights.\")\n", - "\n", - "model.use_kv_cache_for_inference = True\n", - "\n", - "model.cuda()\n", - "model.eval()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "843167ec", - "metadata": {}, - "outputs": [], - "source": [ - "test_dataset = T5TTSDataset(\n", - " dataset_meta={},\n", - " sample_rate=model_cfg.sample_rate,\n", - " min_duration=0.5,\n", - " max_duration=20,\n", - " codec_model_downsample_factor=model_cfg.codec_model_downsample_factor,\n", - " bos_id=model.bos_id,\n", - " eos_id=model.eos_id,\n", - " context_audio_bos_id=model.context_audio_bos_id,\n", - " context_audio_eos_id=model.context_audio_eos_id,\n", - " audio_bos_id=model.audio_bos_id,\n", - " audio_eos_id=model.audio_eos_id,\n", - " num_audio_codebooks=model_cfg.num_audio_codebooks,\n", - " prior_scaling_factor=None,\n", - " load_cached_codes_if_available=True,\n", - " dataset_type='test',\n", - " tokenizer_config=None,\n", - " load_16khz_audio=model.model_type == 'single_encoder_sv_tts',\n", - " use_text_conditioning_tokenizer=model.use_text_conditioning_encoder,\n", - " pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,\n", - " context_duration_min=model.cfg.get('context_duration_min', 5.0),\n", - " context_duration_max=model.cfg.get('context_duration_max', 5.0),\n", - ")\n", - "test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "cbd78348", - "metadata": {}, - "source": [ - "### Set dialogues to generate\n", - "\n", - "Each item in the list can be a single-turn or a multi-turn dialogue.\n", - "\n", - "[SPK-BWL-B-F] is the speaker tag for female speaker and [SPK-BWL-B-M] is the speaker tag for male speaker.\n", - "\n", - "ChatGPT prompt that can generate something like this:\n", - "\n", - "```\n", - "Generate dialogues for a 30 second podcast about .\n", - "The conversation should be between a male and female speaker formatted as follows:\n", - "[SPK-BWL-B-F] Sentence by a female speaker\n", - "[SPK-BWL-B-M] Sentence by a male speaker\n", - "where [SPK-BWL-B-F] and [SPK-BWL-B-M] indicate speaker tags. Dont have any quotation marks in the text, and if there are any numbers spell them out. Basically, keep the text normalized suitable for a TTS model. Keep the conversation fun and engaging with the speakers talking and responding to each other. \n", - "```\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5e32a27", - "metadata": {}, - "outputs": [], - "source": [ - "dialogues = [\"[SPK-BWL-B-F] Have you been keeping up with multi-turn TTS models? They are so fascinating [SPK-BWL-B-M] Absolutely. The way they handle conversations feels almost human.\",\n", - "\"[SPK-BWL-B-F] I know. It is amazing how they can remember context across multiple exchanges [SPK-BWL-B-M] Exactly. It is not just about saying words anymore. 
They actually make the responses flow naturally.\",\n", - "\"[SPK-BWL-B-F] And it is not just for assistants or chatbots. I heard they are being used for things like audiobooks and interactive storytelling.\",\n", - "\"[SPK-BWL-B-M] That is true. Plus, they can even adjust tone to match emotions. Imagine hearing a story where the narrator sounds genuinely excited during a twist.\",\n", - "\"[SPK-BWL-B-F] Or even a bit sarcastic during a funny moment. That makes everything feel more real.\",\n", - "\"[SPK-BWL-B-M] For sure. It is like conversations with a machine are finally catching up to the way we actually talk.\",\n", - "\"[SPK-BWL-B-F] The future of tech keeps surprising me. Every day, there is something new to explore.\",]\n", - "dialogues = [d for d in dialogues if len(d) > 0]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "aa182540", - "metadata": {}, - "source": [ - "### Generation\n", - "\n", - "Below code generates 4 samples for each item in the dialogues list. Then asks, which one you like the best (index 0,1,2 or 3) and adds that to the already generated dialogue. You may modify the below code to automate and just select the first generation if you dont want to do this manually. After every dialogue it also plays the combined dialogues until that item." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "168f157d", - "metadata": {}, - "outputs": [], - "source": [ - "from pydub import AudioSegment\n", - "\n", - "def play_combined_audio(audio_files):\n", - " combined_audio = AudioSegment.empty()\n", - " for file in audio_files:\n", - " audio = AudioSegment.from_file(file)\n", - " combined_audio += audio\n", - " output_path = os.path.join(out_dir, \"combined_audio.wav\")\n", - " combined_audio.export(output_path, format=\"wav\")\n", - " display(Audio(output_path))\n", - " return output_path\n", - " \n", - "\n", - "\n", - "context_path = None\n", - "generated_audios = []\n", - "for didx, dialogue in enumerate(dialogues):\n", - " audio_dir = \"/\"\n", - " entry = {\n", - " \"audio_filepath\": dummy_audio_filepath,\n", - " \"duration\": 3.0,\n", - " \"text\": dialogue,\n", - " \"speaker\": \"dummy\",\n", - " }\n", - " if didx == 0:\n", - " entry[\"context_text\"] = \"MIXED SPEECH TTS\"\n", - " else:\n", - " entry['context_audio_filepath'] = context_path\n", - " entry['context_audio_duration'] = 5.0\n", - " \n", - " \n", - " \n", - " data_sample = DatasetSample(\n", - " dataset_name=\"sample\",\n", - " manifest_entry=entry,\n", - " audio_dir=audio_dir,\n", - " feature_dir=audio_dir,\n", - " text=entry['text'],\n", - " speaker=None,\n", - " speaker_index=0,\n", - " tokenizer_names=[\"english_phoneme\"]\n", - " )\n", - " test_dataset.data_samples = [data_sample]\n", - "\n", - " test_data_loader = torch.utils.data.DataLoader(\n", - " test_dataset,\n", - " batch_size=1,\n", - " collate_fn=test_dataset.collate_fn,\n", - " num_workers=0,\n", - " shuffle=False\n", - " )\n", - " \n", - " \n", - " item_idx = 0\n", - " for bidx, batch in enumerate(test_data_loader):\n", - " print(\"Processing batch {} out of {}\".format(bidx, len(test_data_loader)))\n", - " model.t5_decoder.reset_cache(use_cache=True)\n", - " batch_cuda ={}\n", - " for key in batch:\n", - " if isinstance(batch[key], torch.Tensor):\n", - " batch_cuda[key] = batch[key].cuda()\n", - " else:\n", - " batch_cuda[key] = batch[key]\n", - " import time\n", - " \n", - " candidates = []\n", - " for try_idx in range(4):\n", - " st = time.time()\n", - " predicted_audio, predicted_audio_lens, _, _ = 
model.infer_batch(\n", - " batch_cuda, \n", - " max_decoder_steps=500, \n", - " temperature=0.6, \n", - " topk=80, \n", - " use_cfg=True, \n", - " cfg_scale=1.6\n", - " )\n", - " print(\"generation time\", time.time() - st)\n", - " for idx in range(predicted_audio.size(0)):\n", - " predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()\n", - " predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]\n", - " audio_path = os.path.join(out_dir, f\"predicted_audio_{try_idx}_{didx}_{item_idx}.wav\")\n", - " sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)\n", - " print(\"Dialogue:\", item_idx, \"Candidate: \", try_idx)\n", - " display(Audio(audio_path))\n", - " candidates.append(audio_path)\n", - " \n", - " user_input = input(\"Enter Candidate number that sounds the best:\").strip().lower()\n", - " selected_audio_idx = int(user_input)\n", - " audio_path = candidates[selected_audio_idx]\n", - " item_idx += 1\n", - " generated_audios.append(audio_path)\n", - " print(\"Podcast generated until now:\")\n", - " combined_audio_path = play_combined_audio(generated_audios)\n", - " last_gen_audio = AudioSegment.from_file(combined_audio_path)\n", - " last_5_seconds = last_gen_audio[-5000:] # Duration is in milliseconds\n", - " context_path = os.path.join(out_dir, \"context.wav\")\n", - " last_5_seconds.export(context_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1368c380", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "943a791f", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 2338adb30e7694bdd4d1be06d6013a8b534a5749 Mon Sep 17 00:00:00 2001 From: blisc Date: Mon, 3 Nov 2025 20:36:58 +0000 Subject: [PATCH 104/113] Apply isort and black reformatting Signed-off-by: blisc --- .../tokenizers/text_to_speech/tts_tokenizers.py | 15 +++++++++++++-- .../tts/parts/preprocessing/features.py | 5 ++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 567dfc6b863b..bb6b362652d1 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -51,6 +51,7 @@ class BaseTokenizer(ABC): add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), if None then no blank in labels. """ + PAD, BLANK, OOV = '', '', '' def __init__(self, tokens, *, pad=PAD, blank=BLANK, oov=OOV, sep='', add_blank_at=None): @@ -106,6 +107,7 @@ class BaseCharsTokenizer(BaseTokenizer): non_default_punct_list: List of punctuation marks which will be used instead default. text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. 
""" + # fmt: off # TODO @xueyang: unify definition of the default PUNCT_LIST and import from ipa_lexicon.py PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally @@ -186,6 +188,7 @@ class EnglishCharsTokenizer(BaseCharsTokenizer): text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. Basically, it replaces all non-unicode characters with unicode ones and apply lower() function. """ + def __init__( self, punct=True, @@ -359,6 +362,7 @@ class ItalianCharsTokenizer(BaseCharsTokenizer): pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. non_default_punct_list: List of punctuation marks which will be used instead default. """ + PUNCT_LIST = get_ipa_punctuation_list("it-IT") def __init__( @@ -389,6 +393,7 @@ class GermanPhonemesTokenizer(BaseCharsTokenizer): text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. Currently, it only applies lower() function. """ + # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -460,6 +465,7 @@ class ItalianPhonemesTokenizer(BaseCharsTokenizer): text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. Currently, it only applies lower() function. """ + # fmt: off PUNCT_LIST = ( ',', '.', '!', '?', '-', @@ -544,6 +550,7 @@ class EnglishPhonemesTokenizer(BaseTokenizer): Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be handled by g2p). """ + # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -707,6 +714,7 @@ class IPATokenizer(BaseTokenizer): if None then no blank in labels. pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. """ + def __init__( self, g2p, @@ -876,6 +884,7 @@ class ChinesePhonemesTokenizer(BaseTokenizer): Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be handled by g2p). """ + # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -952,8 +961,9 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): if p == space and len(ps) > 0 and ps[-1] != space: ps.append(p) # Add next phoneme or tone or ascii letter or apostrophe. - elif ((p.isalnum() or p == "'" or p in self.phoneme_list + self.tone_list + self.ascii_letter_list) and - p in tokens): + elif ( + p.isalnum() or p == "'" or p in self.phoneme_list + self.tone_list + self.ascii_letter_list + ) and p in tokens: ps.append(p) # Add punctuation elif (p in self.PUNCT_LIST) and self.punct: @@ -1095,6 +1105,7 @@ class AggregatedTTSTokenizer: tokenizers: List of tokenizers to aggregate. tokenizer_names: List of names for each tokenizer (usually the language identifier). """ + def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase]], tokenizer_names: List[str]): assert len(tokenizers) == len(tokenizer_names), "Number of tokenizers and tokenizer names must be the same." 
tokens = [] diff --git a/nemo/collections/tts/parts/preprocessing/features.py b/nemo/collections/tts/parts/preprocessing/features.py index 307f54312214..1e8a7c9ec995 100644 --- a/nemo/collections/tts/parts/preprocessing/features.py +++ b/nemo/collections/tts/parts/preprocessing/features.py @@ -77,7 +77,10 @@ def _get_feature_filepath( def _features_exists( - feature_names: List[Optional[str]], manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path, + feature_names: List[Optional[str]], + manifest_entry: Dict[str, Any], + audio_dir: Path, + feature_dir: Path, ) -> bool: for feature_name in feature_names: if feature_name is None: From c78bc759028ff71ee2c76dc472b09845566bce90 Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 3 Nov 2025 12:40:05 -0800 Subject: [PATCH 105/113] remove experimental imports Signed-off-by: Jason --- .../common/tokenizers/text_to_speech/tts_tokenizers.py | 1 - nemo/collections/tts/data/text_to_speech_dataset.py | 1 - nemo/collections/tts/data/vocoder_dataset.py | 1 - nemo/collections/tts/g2p/models/i18n_ipa.py | 1 - nemo/collections/tts/models/audio_codec.py | 1 - nemo/collections/tts/modules/encodec_modules.py | 1 - nemo/collections/tts/parts/preprocessing/feature_processors.py | 2 -- nemo/collections/tts/parts/preprocessing/features.py | 1 - nemo/collections/tts/parts/utils/callbacks.py | 1 - 9 files changed, 10 deletions(-) diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 567dfc6b863b..428899a78f10 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -37,7 +37,6 @@ vietnamese_text_preprocessing, ) from nemo.utils import logging -from nemo.utils.decorators import experimental class BaseTokenizer(ABC): diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 29fbe9646848..9d1ec7b3e958 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -37,7 +37,6 @@ ) from nemo.core.classes import Dataset from nemo.utils import logging -from nemo.utils.decorators import experimental @dataclass diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py index b6ea56e2ed7f..84be82f47959 100644 --- a/nemo/collections/tts/data/vocoder_dataset.py +++ b/nemo/collections/tts/data/vocoder_dataset.py @@ -36,7 +36,6 @@ from nemo.core.classes import Dataset, IterableDataset from nemo.utils import logging from nemo.utils import webdataset as wds -from nemo.utils.decorators import experimental from nemo.utils.distributed import webdataset_split_by_workers VALID_FILE_FORMATS = ';'.join(['wav', 'mp3', 'flac', 'opus'] + [fmt.lower() for fmt in valid_sf_formats.keys()]) diff --git a/nemo/collections/tts/g2p/models/i18n_ipa.py b/nemo/collections/tts/g2p/models/i18n_ipa.py index eb254bfab9f8..6a2927db4be4 100644 --- a/nemo/collections/tts/g2p/models/i18n_ipa.py +++ b/nemo/collections/tts/g2p/models/i18n_ipa.py @@ -28,7 +28,6 @@ from nemo.collections.tts.g2p.models.base import BaseG2p from nemo.collections.tts.g2p.utils import GRAPHEME_CASE_MIXED, GRAPHEME_CASE_UPPER, set_grapheme_case from nemo.utils import logging -from nemo.utils.decorators import experimental class IpaG2p(BaseG2p): diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 
1601350457cf..4f4ca21edb30 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -42,7 +42,6 @@ from nemo.core.neural_types.neural_type import NeuralType from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler from nemo.utils import logging, model_utils -from nemo.utils.decorators import experimental try: import torchaudio diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index 34c35bb105d0..db32566a536e 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -59,7 +59,6 @@ from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, LossType, VoidType from nemo.core.neural_types.neural_type import NeuralType from nemo.utils import logging -from nemo.utils.decorators import experimental class SEANetResnetBlock(NeuralModule): diff --git a/nemo/collections/tts/parts/preprocessing/feature_processors.py b/nemo/collections/tts/parts/preprocessing/feature_processors.py index afb0fffa5928..ccbc2057101f 100644 --- a/nemo/collections/tts/parts/preprocessing/feature_processors.py +++ b/nemo/collections/tts/parts/preprocessing/feature_processors.py @@ -19,8 +19,6 @@ import torch -from nemo.utils.decorators import experimental - class FeatureProcessor(ABC): @abstractmethod diff --git a/nemo/collections/tts/parts/preprocessing/features.py b/nemo/collections/tts/parts/preprocessing/features.py index 307f54312214..32ef095c635c 100644 --- a/nemo/collections/tts/parts/preprocessing/features.py +++ b/nemo/collections/tts/parts/preprocessing/features.py @@ -25,7 +25,6 @@ from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor from nemo.collections.tts.parts.utils.tts_dataset_utils import get_audio_filepaths, normalize_volume, stack_tensors -from nemo.utils.decorators import experimental class Featurizer(ABC): diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 366ff864ffea..f0d2dc363237 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -31,7 +31,6 @@ from nemo.collections.tts.parts.utils.helpers import create_plot from nemo.utils import logging -from nemo.utils.decorators import experimental HAVE_WANDB = True try: From 8e581a3502919ae98449f2486fe517a13402fa4a Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 3 Nov 2025 12:42:14 -0800 Subject: [PATCH 106/113] more flake8 Signed-off-by: Jason --- .../common/tokenizers/text_to_speech/tts_tokenizers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 1c64ccc505a4..9cd20cfc83b5 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -680,6 +680,7 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): @contextmanager def set_phone_prob(self, prob): + """Updates the phone probability inside context""" if hasattr(self.g2p, "phoneme_probability"): self.g2p.phoneme_probability = prob try: @@ -854,6 +855,7 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None) - @contextmanager def set_phone_prob(self, prob): + """Updates the phone probability inside context""" if hasattr(self.g2p, 
"phoneme_probability"): self.g2p.phoneme_probability = prob try: @@ -1147,10 +1149,12 @@ def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase raise ValueError("AggregatedTTSTokenizer could not find a padding token in the first tokenizer") def encode(self, text: str, tokenizer_name: str = None) -> List[int]: + """Tokenizer encode from text to tokens""" tokenizer = self.tokenizers[tokenizer_name] tokens = tokenizer.encode(text) return [self.tokenizer_offsets[tokenizer_name] + token for token in tokens] def decode(self, tokens: List[int], tokenizer_name: str = None) -> str: + """Tokernizer decoder from tokens to text""" tokenizer = self.tokenizers[tokenizer_name] return tokenizer.decode([token - self.tokenizer_offsets[tokenizer_name] for token in tokens]) From 4920950cd9500f07bf5f411edd73959d65cd6fb9 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 4 Nov 2025 07:09:31 -0800 Subject: [PATCH 107/113] Add in fix from #15028 by @matteolippi Signed-off-by: Jason --- nemo/collections/tts/data/text_to_speech_dataset.py | 4 ++-- .../tts/models/magpietts_preference_optimization.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 9d1ec7b3e958..aace2f198f26 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -193,8 +193,8 @@ def _preprocess_manifest( sample = DatasetSample( dataset_name=dataset_name, manifest_entry=entry, - audio_dir=dataset.audio_dir, - feature_dir=dataset.feature_dir, + audio_dir=Path(dataset.audio_dir), + feature_dir=Path(dataset.feature_dir), text=text, speaker=speaker, speaker_index=speaker_index, diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index 84e845cec2f5..cf59fcafab8b 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -1051,8 +1051,8 @@ def process_text_for_cer(input_text): single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) # @shehzeen: Added this to handle some common errors in ASR transcripts - single_space_text.replace("h t t p", "http") - single_space_text.replace("w w w", "www") + single_space_text = single_space_text.replace("h t t p", "http") + single_space_text = single_space_text.replace("w w w", "www") return single_space_text From 9b99d2d4e9659240848a5a96c3824d7243fee07b Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 4 Nov 2025 07:37:53 -0800 Subject: [PATCH 108/113] attempt to address some codeQL issues Signed-off-by: Jason --- .../tts/models/magpietts_preference_optimization.py | 7 +++++++ scripts/magpietts/evaluate_generated_audio.py | 2 ++ scripts/magpietts/infer_and_evaluate.py | 10 +++++++--- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index cf59fcafab8b..1001d842da46 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -663,6 +663,12 @@ def generate_and_reward( ) pred_transcripts.append(transcript) pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] + else: + # Address CodeQL issue where pred_transcripts 
might be undefined for future code + raise ValueError( + f"{self} received a value of {self.cfg.get("reward_asr_model", "nemo")} in cfg.reward_asr_model " + "but this class only supports 'nemo' or 'whisper'." + ) pred_speaker_embeddings = get_speaker_embeddings_from_filepaths( predicted_audio_paths, self.eval_speaker_verification_model, self.device @@ -888,6 +894,7 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): policy_model_outputs = self.process_batch(batch_repeated) + reference_model_output = None # Address CodeQL issue even though this varibable is only used not self.reference_free if not self.reference_free: with torch.no_grad(): reference_model_output = self._reference_model.process_batch(batch_repeated) diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index e780c7108cf6..20d93dcad3a4 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -199,6 +199,8 @@ def evaluate( device = "cuda" + whisper_processor = None # Address CodeQL issue even though this varibable is only used when language != "en" + utmosv2_scores = None # Address CodeQL issue even though this varibable is only used when with_utmosv2 is true if language == "en": if asr_model_name.startswith("nvidia/") or asr_model_name in ["stt_en_conformer_transducer_large"]: asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=asr_model_name) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 54b1072b6893..cbc0c452a4c7 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -292,11 +292,14 @@ def run_inference( hparams_file_from_wandb=False, log_exp_name=False, compute_fcd=False, - violin_plot_metrics=['cer', 'pred_context_ssim'], + violin_plot_metrics=None, eos_detection_method=None, ignore_finished_sentence_tracking=False, with_utmosv2=True, ): + # Avoid lists as default values and apply default value in function + if violin_plot_metrics is None: + violin_plot_metrics = ['cer', 'pred_context_ssim', 'utmosv2'] # Load model if hparams_file is not None and checkpoint_file is not None: model_cfg = OmegaConf.load(hparams_file) @@ -754,6 +757,7 @@ def main(): with_utmosv2=not args.disable_utmosv2, ) + cer, ssim = None, None # Mode 1: Run inference from provided hparams and checkpoint files if ( (args.hparams_files is not None) @@ -789,9 +793,9 @@ def main(): "1. --hparams_files and --checkpoint_files\n" "2. 
--nemo_file\n" ) - if args.cer_target is not None and cer > float(args.cer_target): + if cer is not None and args.cer_target is not None and cer > float(args.cer_target): raise ValueError() - if args.ssim_target is not None and ssim < float(args.ssim_target): + if ssim is not None and args.ssim_target is not None and ssim < float(args.ssim_target): raise ValueError() From 0f4045a86db06594f27daea9ae17a56440986000 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 4 Nov 2025 07:39:59 -0800 Subject: [PATCH 109/113] fix typo Signed-off-by: Jason --- .../collections/tts/models/magpietts_preference_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index 1001d842da46..cd380c05ee0b 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -666,7 +666,7 @@ def generate_and_reward( else: # Address CodeQL issue where pred_transcripts might be undefined for future code raise ValueError( - f"{self} received a value of {self.cfg.get("reward_asr_model", "nemo")} in cfg.reward_asr_model " + f"{self} received a value of {self.cfg.get('reward_asr_model', 'nemo')} in cfg.reward_asr_model " "but this class only supports 'nemo' or 'whisper'." ) From 311a6fcfcc459a995ffb228826f59be4782af9d3 Mon Sep 17 00:00:00 2001 From: blisc Date: Tue, 4 Nov 2025 15:40:57 +0000 Subject: [PATCH 110/113] Apply isort and black reformatting Signed-off-by: blisc --- .../tts/models/magpietts_preference_optimization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index cd380c05ee0b..e2506d08e497 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -894,7 +894,9 @@ def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): policy_model_outputs = self.process_batch(batch_repeated) - reference_model_output = None # Address CodeQL issue even though this varibable is only used not self.reference_free + reference_model_output = ( + None # Address CodeQL issue even though this varibable is only used not self.reference_free + ) if not self.reference_free: with torch.no_grad(): reference_model_output = self._reference_model.process_batch(batch_repeated) From 79b0ecaeb33cf155fe6711a6ddfab939696ba14d Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 4 Nov 2025 09:06:39 -0800 Subject: [PATCH 111/113] Remove single_encoder_sv_tts and decoder_pretrain_synthesizer model types from magpie and Clean up scripts Signed-off-by: Jason --- nemo/collections/tts/models/magpietts.py | 136 +-- scripts/magpietts/README.md | 46 - scripts/magpietts/README_lhotse.md | 349 ------- scripts/magpietts/README_magpie_po.md | 87 +- scripts/magpietts/codec_extraction.py | 272 ----- .../create_text_context_lhotse_manifest.py | 224 ----- scripts/magpietts/eval_squimmos.py | 86 -- scripts/magpietts/evalset_config.py | 3 + scripts/magpietts/evaluate_generated_audio.py | 3 + ...extend_nemo_manifest_with_context_audio.py | 930 ------------------ scripts/magpietts/infer_and_evaluate.py | 5 + scripts/tts_dataset_to_lhotse/README.md | 82 -- scripts/tts_dataset_to_lhotse/create_shars.py | 163 --- 13 files changed, 98 insertions(+), 2288 deletions(-) delete mode 
100644 scripts/magpietts/README.md delete mode 100644 scripts/magpietts/README_lhotse.md delete mode 100644 scripts/magpietts/codec_extraction.py delete mode 100644 scripts/magpietts/create_text_context_lhotse_manifest.py delete mode 100644 scripts/magpietts/eval_squimmos.py delete mode 100644 scripts/magpietts/extend_nemo_manifest_with_context_audio.py delete mode 100644 scripts/tts_dataset_to_lhotse/README.md delete mode 100644 scripts/tts_dataset_to_lhotse/create_shars.py diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 44d89cdda843..942e639a641c 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -71,10 +71,6 @@ class MagpieTTSModel(ModelPT): Supports multiple model types: - - single_encoder_sv_tts: Transcript goes into the encoder and target audio goes to the decoder. Additionally, - speaker_embedding of target audio (or context audio if provided) from TitaNet gets added to encoder - output(all timesteps). - - multi_encoder_context_tts: Transcript and context audio go to different encoders. Transcript encoding feeds to layers given by cfg.model.transcript_decoder_layers and the context encoding feeds into the layers given by context_decoder_layers .Also supports text context which gets encoded by the same encoder as context audio. @@ -85,8 +81,8 @@ class MagpieTTSModel(ModelPT): value (5 seconds). Text context, which is usually shorter than number of codec frames of 5 second of audio, is padded to the max context duration in this model. - - decoder_pretrain_synthesizer: This is the model type used for pretraining the decoder only on audio data using - next frame prediction loss. + - decoder_ce: Same as decoder_context_tts except there is a small neural network between the context tensors and + the decoder input. 
""" def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): @@ -203,7 +199,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.eos_id = num_tokens - 1 self.model_type = cfg.get('model_type', None) - self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce'] self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) @@ -237,33 +232,29 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim)) self.audio_embeddings = nn.ModuleList(audio_embeddings) - if self.model_type != 'decoder_pretrain_synthesizer': - # Decoder pretrain synthesizer doesn't have transcript encoder/text embeddings - - if self.use_bpe_char_tokenizer: - # BPE char tokenizer - assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer" - tokenizer_name = self.tokenizer.tokenizer_names[0] - tokenizer = self.tokenizer.tokenizers[tokenizer_name] - subword_vocab = tokenizer.get_vocab() - # special tokens will be stored as it is in the char_vocab - # Each special token will only be mapped to one char id - special_vocab = { - '': self.bos_id, - '': self.eos_id, - } - self.cas_encoder = CharAwareSubwordEncoder( - d_embed=cfg.embedding_dim, - llm_tokenizer_vocab=subword_vocab, - subword_padding_idx=self.tokenizer.pad, - special_vocab=special_vocab, - ) - else: - # Regular text embedding - self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim) - - self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) + if self.use_bpe_char_tokenizer: + # BPE char tokenizer + assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer" + tokenizer_name = self.tokenizer.tokenizer_names[0] + tokenizer = self.tokenizer.tokenizers[tokenizer_name] + subword_vocab = tokenizer.get_vocab() + # special tokens will be stored as it is in the char_vocab + # Each special token will only be mapped to one char id + special_vocab = { + '': self.bos_id, + '': self.eos_id, + } + self.cas_encoder = CharAwareSubwordEncoder( + d_embed=cfg.embedding_dim, + llm_tokenizer_vocab=subword_vocab, + subword_padding_idx=self.tokenizer.pad, + special_vocab=special_vocab, + ) + else: + # Regular text embedding + self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim) + self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) self.final_proj = nn.Linear( cfg.decoder.d_model, @@ -304,18 +295,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): temperature=15.0, ) - if self.model_type == 'single_encoder_sv_tts': - # Context audio goes through Titanet to get speaker embedding - # Speaker embedding is added to the transcript encoder output - self._speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - model_name='titanet_large' - ) - self._speaker_verification_model.freeze() # Lightning does requires_grad = False and self.eval() - self.speaker_projection_layer = nn.Linear(cfg.speaker_emb_dim, cfg.embedding_dim) - self.transcript_decoder_layers = [ - idx for idx in range(self.decoder.n_layers) - ] # All layers are used for text - elif self.model_type == 'multi_encoder_context_tts': + if self.model_type == 'multi_encoder_context_tts': + logging.warning(f"The multi_encoder_context_tts model type for {self} is deprecated.") + # Transcript and context audio/text go to 
different encoders. # Output of the encoders goes to the decoder through the cross-attention layers self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8]) @@ -341,10 +323,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.transcript_decoder_layers = [ idx for idx in range(cfg.decoder.n_layers) ] # All layers are used for text - - elif self.model_type == 'decoder_pretrain_synthesizer': - # This is for pretraining the decoder only on audio data using next frame prediction loss - assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" else: raise ValueError(f"Unsupported model type {self.model_type}") @@ -385,6 +363,9 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): """ Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model from the checkpoint. The codec model is saved in a separate checkpoint. + + _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts + model_type that is no longer supported and can likely be removed in a future version. """ if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} @@ -439,6 +420,9 @@ def load_state_dict(self, state_dict, strict=True): strict is True. When strict is False, we can call pytorch's load_state_dict. When strict is True, we loop through all parameters and rename them to enable loading. + + _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts + model_type that is no longer supported and can likely be removed in a future version. """ state_dict = self.update_ckpt(state_dict) if strict == False: @@ -533,16 +517,6 @@ def embed_audio_tokens(self, audio_tokens): audio_embedding = audio_embedding / (C * self.frame_stacking_factor) return audio_embedding - def get_speaker_embeddings(self, audio_16khz, audio_len_16khz): - # audio_16khz: (B, T) - # audio_len_16khz: (B,) - self._speaker_verification_model.eval() - with torch.no_grad(): - _, speaker_embeddings = self._speaker_verification_model.forward( - input_signal=audio_16khz, input_signal_length=audio_len_16khz - ) - return speaker_embeddings - def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_offset_by_one=False): """ Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes. 
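For reference, the `maskgit_create_random_mask` hunk below boils down to: for every (batch, frame) position, draw u ~ U(0, 1), map it to a masking fraction with `cosine_schedule`, and mask that fraction of the codebooks at that position. A minimal, un-vectorized sketch of that logic follows; the `(B, C, T)` code layout and the `cos(u * pi / 2)` schedule are assumptions borrowed from the MaskGIT literature and are not spelled out in this diff.

```python
import math

import torch


def cosine_schedule(u: torch.Tensor) -> torch.Tensor:
    # Assumed schedule: fraction of codebooks to mask for u sampled uniformly in [0, 1).
    return torch.cos(u * math.pi / 2)


def create_random_mask_reference(codes: torch.Tensor) -> torch.Tensor:
    # codes: (B, C, T) integer codec tokens (layout assumed for this sketch).
    B, C, T = codes.shape
    rand_values = torch.rand(B, T)                 # one draw per (batch, frame)
    frac_masked = cosine_schedule(rand_values)     # fraction of the codebook axis to mask
    n_masked = torch.ceil(frac_masked * C).long()  # (B, T)

    mask = torch.zeros_like(codes, dtype=torch.bool)
    # Plain-loop equivalent of what the model implements in vectorized form.
    for b in range(B):
        for t in range(T):
            k = int(n_masked[b, t])
            picked = torch.randperm(C)[:k]         # choose k codebooks to mask at this frame
            mask[b, picked, t] = True
    return mask


if __name__ == "__main__":
    dummy_codes = torch.randint(0, 2048, (2, 8, 10))
    print(create_random_mask_reference(dummy_codes).float().mean(dim=1))
```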
@@ -624,8 +598,6 @@ def maskgit_create_random_mask(self, codes): frac_masked = cosine_schedule(rand_values) # how many positions to mask n_masked = torch.ceil(frac_masked * C).long() # B,T - # start from all unmasked - mask = torch.zeros_like(codes, dtype=torch.bool) # The code further below is the vectorized version of this: # for b in range(B): # for t in range(T): @@ -1416,27 +1388,16 @@ def prepare_context_tensors(self, batch): text = None text_lens = None - # self.model_type must be one of - # [single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts, decoder_ce, decoder_pretrain_synthesizer] - if self.model_type != 'decoder_pretrain_synthesizer': - text = batch['text'] - text_lens = batch['text_lens'] - text_mask = get_mask_from_lengths(text_lens) # (B, T) - text_embedded = self.embed_text(text, text_mask) # (B, T, E) - text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E) - _attn_prior = batch.get('align_prior_matrix', None) - _attn_prior = self.scale_prior(_attn_prior, self.global_step) - - if self.model_type == 'single_encoder_sv_tts': - target_audio_16khz = batch['audio_16khz'] - target_audio_lens_16khz = batch['audio_lens_16khz'] - speaker_embeddings = self.get_speaker_embeddings(target_audio_16khz, target_audio_lens_16khz) - speaker_embeddings_projected = self.speaker_projection_layer(speaker_embeddings) - cond = text_encoder_out + speaker_embeddings_projected.unsqueeze(1) - cond_mask = text_mask - multi_encoder_mapping = None - attn_prior = _attn_prior - elif self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']: + # self.model_type must be one of [multi_encoder_context_tts, decoder_context_tts, decoder_ce] + text = batch['text'] + text_lens = batch['text_lens'] + text_mask = get_mask_from_lengths(text_lens) # (B, T) + text_embedded = self.embed_text(text, text_mask) # (B, T, E) + text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E) + _attn_prior = batch.get('align_prior_matrix', None) + _attn_prior = self.scale_prior(_attn_prior, self.global_step) + + if self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']: if 'context_audio_codes' in batch: context_audio_codes = batch['context_audio_codes'] context_audio_codes_lens = batch['context_audio_codes_lens'] @@ -1519,8 +1480,6 @@ def prepare_context_tensors(self, batch): multi_encoder_mapping = None additional_decoder_input = context_embeddings additional_decoder_mask = context_mask - elif self.model_type == 'decoder_pretrain_synthesizer': - pass else: raise ValueError(f"Unsupported model type {self.model_type}") @@ -1971,10 +1930,7 @@ def validation_step(self, batch, batch_idx): ) # Get attention image data for logging - if ( - self.model_type != 'decoder_pretrain_synthesizer' - and len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1 - ): + if len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1: # cross_attn_probabilities only returned when not using flash attention cross_attention_probs = [ attn['cross_attn_probabilities'][0] @@ -2674,7 +2630,7 @@ def get_dataset(self, dataset_cfg, dataset_type): text_context_remapping=self.text_context_remapping, text_context_remapping_prob=self.text_context_remapping_prob, ) - dataset.load_16khz_audio = self.model_type == 'single_encoder_sv_tts' + dataset.load_16khz_audio = False dataset.tokenizer_config = ( self.cfg.text_tokenizers ) # This will be used 
in worker_init_fn for instantiating tokenizer @@ -2695,7 +2651,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) - load_16khz_audio=(self.model_type == 'single_encoder_sv_tts'), + load_16khz_audio=False, pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, context_duration_min=self.cfg.context_duration_min, context_duration_max=self.cfg.context_duration_max, diff --git a/scripts/magpietts/README.md b/scripts/magpietts/README.md deleted file mode 100644 index ef99015e65c0..000000000000 --- a/scripts/magpietts/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# MagpieTTS Inference and Evaluation - -To evaluate any MagpieTTS checkpoint you trained follow the steps as shown below (INTERNAL ONLY): - -1) Mount the EOS cluster path `/lustre/fsw/llmservice_nemo_speechlm/data/TTS:/Data` - -All the needed manifests are here: `/lustre/fsw/llmservice_nemo_speechlm/data/TTS/evaluation_manifests` - -2) Run the following command: -``` -CKPT= -HPARAM= -CODEC= -OUT_DIR= - -python scripts/magpietts/infer_and_evaluate.py \ ---checkpoint_files ${CKPT} \ ---hparams_files ${HPARAM} \ ---codecmodel_path ${CODEC} \ ---out_dir ${OUT_DIR} \ ---use_cfg \ ---apply_attention_prior -``` - -**Test Sets** -The Datasets that we evaluate on are: - -- LibriTTS test clean -- LibriTTS seen -- VCTK -- RIVA Hard examples - -**Evaluation Metrics** - -- ASR of the generated speech is done using `nvidia/parakeet-tdt-1.1b` and then CER/WER is computed. -- Speaker Similarity using `titanet` - - - -# Using Lhotse Datasets in MagpieTTS - -Refer to [this file](./README_lhotse.md) for more information about using Lhotse Dataset of MagpieTTS. - -# Preference Alignment of MagpieTTS - -Refer to [this file](./README_magpie_po.md) for more information about preference alignment of MagpieTTS. \ No newline at end of file diff --git a/scripts/magpietts/README_lhotse.md b/scripts/magpietts/README_lhotse.md deleted file mode 100644 index 9a8eb7735d1b..000000000000 --- a/scripts/magpietts/README_lhotse.md +++ /dev/null @@ -1,349 +0,0 @@ -This guidance describes the new Lhotse Shar process for converting NeMo datasets to Lhotse Shar datasets, designed for -training and validating Magpie-TTS. This new version significantly reduces computation overhead by using rank-balanced -workloading and independent writing across parallel processes. It also separate the processes to CPU-only nodes and -GPU-only nodes accordingly. - -## Creating New Lhotse Shar Data - -The process involves four main steps: -1. **Prepare Input Manifests (on CPU nodes):** Standardize the input NeMo manifests for each dataset. -2. **Extend Manifests with Context Audio (on GPU nodes):** Enhance the NeMo manifests by adding context audio information. -3. **Create Lhotse Shards (on CPU nodes):** Convert the extended NeMo manifests into Lhotse shards. -4. **Extend Shards with Audio Codes (on GPU nodes):** Process the Lhotse shards to extract and include audio codes (audio codec extraction). - -### Step 1: Prepare Input Manifests - -This first step runs on **CPU nodes** and is responsible for standardizing the input NeMo manifests for each dataset. This may involve consolidating multiple input files or reformatting entries. 
It's a preparatory step to ensure the manifest is in the correct format for the subsequent stages. - -*Note: The actual implementation for this step ([`prep_input_manifest.py`](https://gitlab-master.nvidia.com/xueyang/nemo-tts-artifacts-registry/-/blob/model_release_2505/model_release_2505/data_prep/hifitts2/prep_input_manifest_iad.py) in the internal scripts) is highly specific to the dataset and environment. Users should create their own script to prepare a standardized manifest file as input for Step 2.* - -A crucial part of this step is to ensure the `speaker` field in the NeMo manifest conforms to the required format: -```python -def check_speaker_format(item: str): - # enforce the format as example like "| Language:en Dataset:HiFiTTS Speaker:9136_other |". - pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" - return bool(re.match(pattern, item)) -``` - -#### Checkout the Outputs of `hifitts2/prep_input_manifest.py` -```bash -$ tree -L 1 -P '*.json|*.txt' hifitts2/nemo_manifest/ -hifitts2/nemo_manifest/ -├── hifitts2_all_splits.json -├── hifitts2_dev_seen.json -├── hifitts2_dev_unseen.json -├── hifitts2_test_seen.json -├── hifitts2_test_unseen.json -├── hifitts2_train.json # This is the standardized NeMo manifest used for the following steps. -├── manifest_empty_normalized_text_fields.json -├── manifest_librosa_error.json -├── manifest_mismatched_audio_duration.json -├── manifest_missing_asr_metrics.json -├── manifest_missing_audio_files.json -├── manifest_missing_required_fields.json -└── stats.txt # This helps to understand the success and failure records. -``` - -### Step 2: Extend NeMo Manifest with Context Audio - -This step runs on **GPU nodes**. It enhances the standardized NeMo manifest from Step 1 by adding context audio information. - -Improvements over older recipes include: -- Speaker embedding extraction is run on the fly, using `torch.matmul` to compute a similarity matrix. -- It recursively finds the next best context audio if the top candidate is unsuitable, preserving more data. -- It is scaling-friendly by pre-allocating a distinct subset of speaker records to each GPU rank for balanced workloads using a greedy bin-packing strategy. -- Manifests are written out in a buffered way to reduce I/O calls. - -#### Example command: -```bash -# General setup -CODE_DIR="/workspace/NeMo" -export PYTHONPATH="${CODE_DIR}:${PYTHONPATH}" -cd ${CODE_DIR} - -# Script parameters -INPUT_MANIFEST="/path/to/hifitts2/nemo_manifest/hifitts2_train.json" # From Step 1 -AUDIO_BASE_DIR="/path/to/audio/files" -OUTPUT_DIR="/path/to/hifitts2/nemo_manifest" -DATASET_NAME="hifitts2" # e.g., hifitts, libritts, etc. Used for speaker ID parsing. 
-CONTEXT_MIN_DURATION=3.0 -CONTEXT_MIN_SSIM=0.6 -BATCH_SIZE=256 -FLUSH_THRESHOLD_ITEMS=256 -NUM_WORKERS=8 -DEVICES=-1 -NUM_NODES=1 -WANDB_ENTITY="xyz" -WANDB_PROJECT="xyz" -WANDB_NAME="xyz" - -echo "****** STEP 2: Extending NeMo Manifest with Context Audio ******" -python scripts/magpietts/extend_nemo_manifest_with_context_audio.py \ - --dataset-name ${DATASET_NAME} \ - --manifest ${INPUT_MANIFEST} \ - --audio-base-dir ${AUDIO_BASE_DIR} \ - --output-dir ${OUTPUT_DIR} \ - --flush-threshold-items ${FLUSH_THRESHOLD_ITEMS} \ - --context-min-duration ${CONTEXT_MIN_DURATION} \ - --context-min-ssim ${CONTEXT_MIN_SSIM} \ - --batch-size ${BATCH_SIZE} \ - --devices ${DEVICES} \ - --num-nodes ${NUM_NODES} \ - --num-workers ${NUM_WORKERS} \ - --wandb-entity ${WANDB_ENTITY} \ - --wandb-project ${WANDB_PROJECT} \ - --wandb-name ${WANDB_NAME} -``` - -#### Checkout the Outputs -```bash -$ tree -L 1 hifitts2/nemo_manifest/extend_nemo_manifest_with_context_audio/ -hifitts2/nemo_manifest/extend_nemo_manifest_with_context_audio/ -├── hifitts2_train_rank0.json -├── hifitts2_train_rank1.json -├── hifitts2_train_rank2.json -├── hifitts2_train_rank3.json -├── hifitts2_train_rank4.json -├── hifitts2_train_rank5.json -├── hifitts2_train_rank6.json -├── hifitts2_train_rank7.json -└── hifitts2_train_withContextAudioMinDur3.0MinSSIM0.6.json # This is the NeMo manifest used for the following steps. -``` - - -### Step 3: Create Lhotse Shards from NeMo Manifest - -This step runs on **CPU nodes**. It converts the extended NeMo manifests (from Step 2) into Lhotse shards. - -Key features: -- Processes chunks of manifest entries, loads audio, and writes corresponding shard files for cuts, target audio, and context audio. -- Designed to be run in parallel worker processes. -- Loads and writes audio iteratively to save memory. - -#### Example command: -```bash -# General setup -CODE_DIR="/workspace/NeMo" -export PYTHONPATH="${CODE_DIR}:${PYTHONPATH}" -cd ${CODE_DIR} - -# Script parameters -EXTENDED_MANIFEST_PATH="/path/to/hifitts2/nemo_manifest/extend_nemo_manifest_with_context_audio/hifitts2_train_withContextAudioMinDur3.0MinSSIM0.6.json" # From Step 2 -AUDIO_BASE_DIR="/path/to/audio/files" -SAVE_DIR="/path/to/lhotse_shar_output" -NUM_WORKERS=16 # Number of CPU cores -SHARD_SIZE=4096 - -echo "****** STEP 3: Creating Lhotse Shards from NeMo Manifest ******" -python scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py \ - --manifest-path ${EXTENDED_MANIFEST_PATH} \ - --audio-base-dir ${AUDIO_BASE_DIR} \ - --output-dir ${SAVE_DIR} \ - --num-jobs ${NUM_WORKERS} \ - --processing-chunk-size ${SHARD_SIZE} \ - --audio-format 'flac' \ - --log-level 'INFO' -``` - -#### Checkout the outpus - -```bash -$ tree -L 3 -P "*.000000.*" hifitts2/lhotse_shar/{cuts,target_audio,context_audio} -hifitts2/lhotse_shar/cuts -└── cuts.000000.jsonl.gz -hifitts2/lhotse_shar/target_audio -└── recording.000000.tar -hifitts2/lhotse_shar/context_audio -└── recording.000000.tar -``` - -### Step 4: Extend Lhotse Shards with Audio Codes - -This final step runs on **GPU nodes**. It processes the Lhotse shards created in Step 3 to extract and add audio codes. - -Improvements include: -- Pre-allocation of Lhotse shards to each rank, with each rank processing and writing independently. -- Pre-allocation of padded audio tensors, avoiding looped calls to `torch.func.pad`. -- Avoids redundant zero-padding that was present in older recipes. 
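The padded-audio bullet above only names the optimization. A minimal sketch of the idea, with hypothetical helper names and plain 1-D waveforms rather than the actual pipeline code: allocate the zero-padded batch tensor once and copy each waveform into its row, instead of issuing one pad call per item.

```python
import torch
import torch.nn.functional as F


def pad_batch_with_loop(waveforms):
    # Baseline: one pad call per waveform, then stack the results.
    max_len = max(w.size(0) for w in waveforms)
    return torch.stack([F.pad(w, (0, max_len - w.size(0))) for w in waveforms])


def pad_batch_preallocated(waveforms):
    # Pre-allocate the padded batch once and copy each waveform into place.
    max_len = max(w.size(0) for w in waveforms)
    batch = torch.zeros(len(waveforms), max_len, dtype=waveforms[0].dtype)
    for i, w in enumerate(waveforms):
        batch[i, : w.size(0)] = w
    return batch


if __name__ == "__main__":
    wavs = [torch.randn(n) for n in (16000, 22050, 8000)]
    assert torch.equal(pad_batch_with_loop(wavs), pad_batch_preallocated(wavs))
```

Both helpers return the same `(batch, max_len)` tensor; the second form simply avoids building and stacking per-item intermediates.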
- -#### Example command: -```bash -# General setup -CODE_DIR="/workspace/NeMo" -export PYTHONPATH="${CODE_DIR}:${PYTHONPATH}" -cd ${CODE_DIR} - -# Codec parameters -CODEC_MODEL_NAME="21fpsCausalDecoder" -CODEC_MODEL_PATH="/path/to/your/codec_model.nemo" -CODEC_FRAME_RATE=21.5 - -# Trainer parameters -DEVICES=-1 # Number of GPUs, -1 for all -NUM_NODES=1 -BATCH_SIZE=64 -WANDB_ENTITY="xyz" -WANDB_PROJECT="xyz" -WANDB_NAME="xyz" - -# Path parameters -SHARD_DIR="/path/to/hifitts2/lhotse_shar" # From Step 3 - -echo "****** STEP 4: Extending Lhotse Shards with Audio Codes ******" -python scripts/magpietts/extend_lhotse_shards_with_audio_codes.py \ - --cuts-dir ${SHARD_DIR}/cuts \ - --target-audio-dir ${SHARD_DIR}/target_audio \ - --context-audio-dir ${SHARD_DIR}/context_audio \ - --output-dir ${SHARD_DIR} \ - --codec-model-name ${CODEC_MODEL_NAME} \ - --codec-model-path ${CODEC_MODEL_PATH} \ - --codec-frame-rate ${CODEC_FRAME_RATE} \ - --devices ${DEVICES} \ - --num-nodes ${NUM_NODES} \ - --batch-size ${BATCH_SIZE} \ - --wandb-entity ${WANDB_ENTITY} \ - --wandb-project ${WANDB_PROJECT} \ - --wandb-name ${WANDB_NAME} \ - --log-level 'INFO' \\ -``` - -### Checking the Outputs - -After running all four steps, you can check the files by looking at the output directory specified in Steps 3 and 4. - -```bash -# Examples of shard files: -$ tree -L 3 -P '*.000000.*' -I log hifitts2/lhotse_shar -hifitts2/lhotse_shar -├── codes_21fpsCausalDecoder # This is the subdir for audio codec codes. -│   ├── context_codes -│   │   └── codes.000000.tar # context audio codec codes. -│   └── target_codes -│   └── codes.000000.tar # target codec codes. -├── context_audio -│   └── recording.000000.tar # context audio waveforms. -├── cuts -│   └── cuts.000000.jsonl.gz # Lhotse manifest. -└── target_audio - └── recording.000000.tar # target audio waveforms. 
-``` - -When peek one of the item from `cuts.000000.jsonl.gz`, you should expect the structure as, -```python -In [4]: cutset = CutSet.from_shar(fields={"cuts": ["hifitts2/lhotse_shar/cuts/cuts.000000.jsonl.gz"], "target_audio": ["hifitts2/lhotse_shar/target_audio/recording.000000.tar"], "context_audio": ["h - ...: ifitts2/lhotse_shar/context_audio/recording.000000.tar"], "target_codes": ["hifitts2/lhotse_shar/codes_21fpsCausalDecoder/target_codes/codes.000000.tar"], "context_codes": ["hifitts2/lhotse_ - ...: shar/codes_21fpsCausalDecoder/context_codes/codes.000000.tar"]}) - -In [5]: cuts_list = [cut for cut in cutset] - -In [12]: from rich import print - -In [13]: print(cuts_list[0]) -MonoCut( - id='cut-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27-0.00-3.49', - start=0.0, - duration=3.49, - channel=0, - supervisions=[ - SupervisionSegment( - id='sup-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27', - recording_id='rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27', - start=0.0, - duration=3.49, - channel=0, - text='he was perhaps five years my senior', - language='en', - speaker='| Language:en Dataset:hifitts2 Speaker:9216 |', - gender=None, - custom={ - 'normalized_text': 'He was perhaps five years my senior.', - 'text_source': 'mls', - 'wer': 0.0, - 'cer': 0.0, - 'speaker_count': 1, - 'bandwidth': 13953, - 'set': 'train', - 'context_speaker_similarity': 0.802838921546936, - 'context_audio_offset': 0.0, - 'context_audio_duration': 12.2, - 'context_audio_text': 'Vision of Helen," he called it, I believe.... The oblique stare of the hostile Trojans. Helen coifed with flame. Menelaus.', - 'context_audio_normalized_text': 'Vision of Helen," he called it, I believe.... The oblique stare of the hostile Trojans. Helen coifed with flame. 
Menelaus.', - 'context_recording_id': 'rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_30' - }, - alignment=None - ) - ], - features=None, - recording=Recording( - id='rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27', - sources=[AudioSource(type='file', channels=[0], source='/audio/9216/8716/9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27.flac')], - sampling_rate=22050, - num_samples=76955, - duration=3.4900226757369612, - channel_ids=[0], - transforms=None - ), - custom={ - 'target_audio': Recording( - id='cut-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_27-0.00-3.49', - sources=[AudioSource(type='memory', channels=[0], source='')], - sampling_rate=22050, - num_samples=76955, - duration=3.49, - channel_ids=[0], - transforms=None - ), - 'context_audio': Recording( - id='context_cut-rec-9216-8716-9216_8716_ohenryawardstoriesof1921_1409_librivox-ohenryawardprizestoriesof1921_07_various_30-0.00-12.20', - sources=[AudioSource(type='memory', channels=[0], source='')], - sampling_rate=22050, - num_samples=269010, - duration=12.2, - channel_ids=[0], - transforms=None - ), - 'target_codes': TemporalArray( - array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 76]), - temporal_dim=-1, - frame_shift=0.046511627906976744, - start=0 - ), - 'context_codes': TemporalArray( - array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 263]), - temporal_dim=-1, - frame_shift=0.046511627906976744, - start=0 - ), - 'shard_origin': 'hifitts2/lhotse_shar/cuts/cuts.000000.jsonl.gz', - 'shar_epoch': 0 - } -) -``` - -## Extending the Existing Lhotse Shar with New Audio Codec Codes -Given existing Lhotse Shar, i.e. cuts/target_audio/context_audio, you could just run the Python script -`scripts/magpietts/extend_lhotse_shards_with_audio_codes.py` by overriding with other audio codec models. The whole -process should be the same as Step 4 as mentioned above. - -## (Internal Only) Sharing Slurm Job Sub Scripts to Create Lhotse Shar -The internal scripts for submitting these steps as Slurm jobs can be found in the GitLab repository `nemo-tts-artifacts-registry` -repository, i.e. https://gitlab-master.nvidia.com/xueyang/nemo-tts-artifacts-registry/-/tree/model_release_2505/model_release_2505/data_prep. These scripts are tailored for specific cluster environments but can be adapted for other systems. 
- -```shell -$ tree -L 1 gitlab/nemo-tts-artifacts-registry/model_release_2505/data_prep/ -gitlab/nemo-tts-artifacts-registry/model_release_2505/data_prep/ -├── 1_submit_jobs_prep_input_manifest_iad.sh -├── 2_submit_jobs_extend_nemo_manifest_with_context_audio_iad.sh -├── 3_submit_jobs_create_lhotse_shards_from_nemo_manifest_iad.sh -├── 4_submit_jobs_extend_lhotse_shards_with_audio_codes_iad.sh -├── hifitts -├── hifitts2 -├── jhsdGtc20Amp20Keynote -├── libritts -├── librittsDevClean -├── nvyt2505 -├── README.md -├── rivaEmmaMeganSeanTom -└── rivaLindyRodney -``` - diff --git a/scripts/magpietts/README_magpie_po.md b/scripts/magpietts/README_magpie_po.md index 4fee74e3b72f..8ba3d2e66655 100644 --- a/scripts/magpietts/README_magpie_po.md +++ b/scripts/magpietts/README_magpie_po.md @@ -4,12 +4,12 @@ Code: `nemo/collections/tts/models/magpietts_preference_optimization.py` Preference Alignment (DPO/RPO) involves the following steps 1) Create a list of text-context pairs for which we will generate preference data. -2) For each text-context pair generate multiple audios from a base T5-TTS checkpoint and calculate metrics (CER/SSIM) for each generation. +2) For each text-context pair generate multiple audios from a base TTS checkpoint and calculate metrics (CER/SSIM) for each generation. 3) Create chosen-rejected pairs from the generated audio. -4) Finetune the base T5-TTS checkpoint on the chosen-rejected pairs. +4) Finetune the base TTS checkpoint on the chosen-rejected pairs. #### 1. Create text-context pairs -We pair a list of challenging texts with context audios from from Riva and LibriTTS dataset. We add a similar number of regular texts from LibriTTS and Riva (paired with random context audios). We also include examples with text contexts. There are other options for generating text-context pairs. +We pair a list of challenging texts with context audios from our speech datasets. We add a similar number of regular transcripts our datasets such as LibriTTS paired with random context audios. We also include examples with text contexts. There are other options for generating text-context pairs. ``` python scripts/magpietts/dpo/create_text_contextpairs.py \ @@ -23,11 +23,11 @@ python scripts/magpietts/dpo/create_text_contextpairs.py \ ``` Each pair is repeated `nsamples_perpair` times which specifies how many samples we want to generate for each pair. The output manifest serves as the input for the next step. -We can also explore other options for these text-context pairs as well depending on the task. +We can also explore other options for these text-context pairs as well depending on the task. #### 2. Generate audios for each text-context pair -Next, we can generate audios from a base T5-TTS checkpoint using the following command. We pass the `audio_dir` as "/" since our text context pairs contains absolute paths. Model config arguments should be modified accordingly to match the base checkpoint architecture. We can run the below command on cluster to generate audios across multiple nodes. This command saves the generated audios along with the metrics for each generation in the `exp_dir`. Each generated audio file is accompanied with a `.json` file that has the CER/SSIM metrics. +Next, we can generate audios from a base TTS checkpoint using the following command. We pass the `audio_dir` as "/" since our text context pairs contains absolute paths. Model config arguments should be modified accordingly to match the base checkpoint architecture. 
We can run the below command on cluster to generate audios across multiple nodes. This command saves the generated audios along with the metrics for each generation in the `exp_dir`. Each generated audio file is accompanied with a `.json` file that has the CER/SSIM metrics. ``` @@ -35,31 +35,28 @@ python examples/tts/magpietts.py \ --config-name=magpietts_inference_en \ mode=test \ batch_size=64 \ -+init_from_ptl_ckpt="/mountdir/checkpoints/continuouscheckpoints_ks1_ks3/decodercontext_small_282.ckpt" \ -exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282" \ -+test_ds_meta.textcontextpairs.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json" \ ++init_from_ptl_ckpt= \ +exp_manager.exp_dir= \ ++test_ds_meta.textcontextpairs.manifest_path= \ +test_ds_meta.textcontextpairs.audio_dir="/" \ +test_ds_meta.textcontextpairs.feature_dir="/" \ -model.model_type="decoder_context_tts" \ -model.encoder.kernel_size=3 \ -model.decoder.kernel_size=1 \ +model.model_type="decoder_context_tts" # Change this as needed \ model.context_duration_min=5.0 \ model.context_duration_max=5.0 \ model.use_text_conditioning_encoder=true \ -model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.codecmodel_path= \ model.alignment_loss_scale=0.002 \ model.prior_scaling_factor=null \ -model.load_cached_codes_if_available=false \ -trainer.num_nodes=${SLURM_JOB_NUM_NODES} +model.load_cached_codes_if_available=false ``` #### 3. Create chosen-rejected pairs from the generations -Next, we go through the generated audio directory and create chosen-rejected pairs. +Next, we go through the generated audio directory and create chosen-rejected pairs. ``` python scripts/magpietts/dpo/create_preference_pairs.py \ ---input_manifest /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json \ ---generated_audio_dir /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/audios \ +--input_manifest \ +--generated_audio_dir /Magpie-TTS-EN-Infer/version_0/audio \ --group_size 6 \ --cer_threshold 0.01 \ --val_size 256 ; @@ -67,7 +64,7 @@ python scripts/magpietts/dpo/create_preference_pairs.py \ `cer_threshold=0.01` means that filter out pairs in which the chosen CER > 0.01. -This command should save train and val manifests for DPO finetuning in the base directory of the generated_audio_dir, that is, `/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/` +This command should save train and val manifests for DPO finetuning in the base directory of the generated_audio_dir, that is, `/Magpie-TTS-EN-Infer/version_0/manifests/` #### 4. 
DPO Finetuning Command @@ -76,36 +73,35 @@ Finally, we perform DPO finetuning using the following command: ``` python examples/tts/magpietts.py \ batch_size=4 \ -+init_from_ptl_ckpt="/mountdir/checkpoints/decoder_21_epoch_2.ckpt" \ ++init_from_ptl_ckpt= \ +mode="dpo_train" \ max_epochs=10 \ -exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/TrainingsICML/decodercontext_small_282" \ +exp_manager.exp_dir= \ exp_manager.checkpoint_callback_params.always_save_nemo=false \ model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ -+train_ds_meta.dpopreftrain.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_train_manifest.json" \ ++train_ds_meta.dpopreftrain.manifest_path="/Magpie-TTS-EN-Infer/version_0/manifests/" \ +train_ds_meta.dpopreftrain.audio_dir="/" \ +train_ds_meta.dpopreftrain.feature_dir="/" \ -+val_ds_meta.dpoprefval.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_val_manifest.json" \ ++val_ds_meta.dpoprefval.manifest_path="/Magpie-TTS-EN-Infer/version_0/manifests/dpo_val_manifest.json" \ +val_ds_meta.dpoprefval.audio_dir="/" \ +val_ds_meta.dpoprefval.feature_dir="/" \ +model.dpo_beta=0.01 \ +model.dpo_sft_loss_weight=0.0 \ -model.model_type="decoder_context_tts" \ +model.model_type="decoder_context_tts" # Change this as needed \ model.context_duration_min=5.0 \ model.context_duration_max=5.0 \ model.use_text_conditioning_encoder=true \ -model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ +model.codecmodel_path= \ model.alignment_loss_scale=0.001 \ model.prior_scaling_factor=null \ trainer.val_check_interval=200 \ trainer.log_every_n_steps=10 \ model.optim.lr=2e-7 \ -~model.optim.sched \ -trainer.num_nodes=${SLURM_JOB_NUM_NODES} +~model.optim.sched ``` -Note the following overrides in the above command: +Note the following overrides in the above command: ``` +mode="dpo_train" \ @@ -138,11 +134,11 @@ To train with GRPO, we use a similar training command as the base model training 1. We start from a pretrained checkpoint supplied using `+init_from_ptl_ckpt` 2. We add `+mode="onlinepo_train"` to specify preference optimization based training. -3. Use a small batch size (bs=2) since we generate `num_generations_per_item` samples per item in the batch and the effective batch size becomes `bs*num_generations_per_item` +3. Use a small batch size (bs=2) since we generate `num_generations_per_item` samples per item in the batch and the effective batch size becomes `bs*num_generations_per_item` 4. The manifest should contain absolute audio paths and the `audio_dir` is specified as "/" in the `train_ds_meta` command. 5. Use the same model specific overrides as the base model (eg. x-attn heads, is_causal, num_layers, local transformer etc). 6. Set dropout probs to 0 for all modules - This is especially important if we are not using reference free mode. KL divergence loss becomes very spiky and unstable. Set prob to 0 by `model.decoder.p_dropout=0.0`. -7. Dont use attention prior or CTC loss during GRPO. +7. Dont use attention prior or CTC loss during GRPO. 8. Add the following GRPO specific arguments in the training command. 
``` @@ -191,24 +187,24 @@ Below is a sample training command for multilingual GRPO: ``` python examples/tts/magpietts.py \ ---config-name=magpietts_multilingual_v1 \ +--config-name=magpietts_multilingual_v1 #TODO(blisc) after updating yamls\ batch_size=2 \ -+init_from_ptl_ckpt="/mountdir/checkpoints/magpie_checkpoints/shared_char_ipa_epoch285.ckpt" \ ++init_from_ptl_ckpt= \ +mode="onlinepo_train" \ -+model.text_tokenizers.chartokenizer._target_=AutoTokenizer \ -+model.text_tokenizers.chartokenizer.pretrained_model="google/byt5-small" \ ++model.text_tokenizers.chartokenizer._target_=AutoTokenizer # Change this as needed \ ++model.text_tokenizers.chartokenizer.pretrained_model="google/byt5-small" # Change this as needed \ max_epochs=20 \ -exp_manager.exp_dir="${DOCKER_EXP_DIR}" \ +exp_manager.exp_dir= \ +exp_manager.version=0 \ exp_manager.checkpoint_callback_params.always_save_nemo=false \ -+train_ds_meta.dpopreftrain.manifest_path="/data/TTS/CML/manifests_with_codecs_ipa3/cml_tts_dataset_portuguese_v0.1/train_withAudioCodes_codec21KhzCausalDecoder_filtered_textcontextpairs_train_GRPO_ipa_NoDuplicates.json" \ ++train_ds_meta.dpopreftrain.manifest_path= \ +train_ds_meta.dpopreftrain.audio_dir="/" \ +train_ds_meta.dpopreftrain.feature_dir="/" \ -+train_ds_meta.dpopreftrain.tokenizer_names="[chartokenizer]" \ -+val_ds_meta.dpoprefval.manifest_path="/data/TTS/CML/manifests_with_codecs_ipa3/cml_tts_dataset_portuguese_v0.1/train_withAudioCodes_codec21KhzCausalDecoder_filtered_textcontextpairs_val_GRPO_ipa.json" \ ++train_ds_meta.dpopreftrain.tokenizer_names="[chartokenizer]" #Change this as needed \ ++val_ds_meta.dpoprefval.manifest_path= \ +val_ds_meta.dpoprefval.audio_dir="/" \ +val_ds_meta.dpoprefval.feature_dir="/" \ -+val_ds_meta.dpoprefval.tokenizer_names="[chartokenizer]" \ ++val_ds_meta.dpoprefval.tokenizer_names="[chartokenizer]" #Change this as needed \ +model.grpo_beta=0.0 \ +model.num_generations_per_item=12 \ +model.reference_free=true \ @@ -226,17 +222,17 @@ model.cfg_unconditional_prob=0.0 \ +model.loss_type="grpo" \ +model.scale_rewards=true \ +model.max_decoder_steps=430 \ -model.model_type="decoder_context_tts" \ +model.model_type="decoder_context_tts" #Change this as needed \ model.context_duration_min=5.0 \ model.context_duration_max=5.0 \ model.decoder.p_dropout=0.0 \ model.encoder.p_dropout=0.0 \ -model.local_transformer_type="autoregressive" \ -model.local_transformer_n_layers=1 \ -model.local_transformer_n_heads=1 \ -model.local_transformer_hidden_dim=256 \ -model.use_text_conditioning_encoder=true \ -model.codecmodel_path="/mountdir/checkpoints/21fps_causal_codecmodel.nemo" \ +model.local_transformer_type="autoregressive" #Change this as needed \ +model.local_transformer_n_layers=1 #Change this as needed \ +model.local_transformer_n_heads=1 #Change this as needed \ +model.local_transformer_hidden_dim=256 #Change this as needed \ +model.use_text_conditioning_encoder=true #Change this as needed \ +model.codecmodel_path= \ model.alignment_loss_scale=0.0 \ model.prior_scaling_factor=null \ ~trainer.check_val_every_n_epoch \ @@ -248,6 +244,5 @@ exp_manager.checkpoint_callback_params.monitor="val_cer_gt" \ exp_manager.checkpoint_callback_params.mode="min" \ trainer.precision=32 \ +trainer.gradient_clip_val=2.5 \ -trainer.num_nodes=${SLURM_JOB_NUM_NODES} ``` diff --git a/scripts/magpietts/codec_extraction.py b/scripts/magpietts/codec_extraction.py deleted file mode 100644 index 3ecb42bc1504..000000000000 --- a/scripts/magpietts/codec_extraction.py +++ /dev/null @@ -1,272 +0,0 @@ 
-# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import json -import os - -import lightning.pytorch as pl -import torch -from lightning.pytorch import Trainer -from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.utilities import rank_zero_only -from torch.utils.data import DataLoader, Dataset - -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.tts.models import AudioCodecModel - - -class AudioDataset(Dataset): - def __init__(self, file_lists, base_audio_dirs, dataset_names, out_dir, sample_rate=22050, pad_multiple=1024): - self.file_list = file_lists - self.base_audio_dirs = base_audio_dirs - self.sample_rate = sample_rate - self.pad_multiple = pad_multiple - self.out_dir = out_dir - self.combined_file_list = [] - for fidx, file_list in enumerate(file_lists): - base_audio_dir = base_audio_dirs[fidx] - dataset_name = dataset_names[fidx] - for file_path in file_list: - audio_file_path = os.path.join(base_audio_dir, file_path) - self.combined_file_list.append( - {"file_path": file_path, "audio_file_path": audio_file_path, "dataset_name": dataset_name} - ) - - def __len__(self): - return len(self.combined_file_list) - - def get_wav_from_filepath(self, file_path): - features = AudioSegment.segment_from_file( - file_path, - target_sr=self.sample_rate, - n_segments=-1, - trim=False, - ) - audio_samples = features.samples - audio = torch.tensor(audio_samples) - audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) - audio_length = torch.tensor(audio.size(0)).long() - return audio, audio_length - - def __getitem__(self, idx): - file_path = self.combined_file_list[idx]["file_path"] - audio_file_path = self.combined_file_list[idx]["audio_file_path"] - dataset_name = self.combined_file_list[idx]["dataset_name"] - assert not file_path.startswith("/"), "file_path should be relative" - audio, audio_length = self.get_wav_from_filepath(audio_file_path) - codec_file_path_rel = file_path.replace(".wav", ".pt").replace(".flac", ".pt") - return { - "audio": audio, - "audio_length": audio_length, - "file_path": file_path, - "codec_file_path": os.path.join(self.out_dir, dataset_name, codec_file_path_rel), - } - - def collate_fn(self, batch): - audios_padded = [] - audio_lengths = [] - file_paths = [] - codec_file_paths = [] - max_audio_length = max(item["audio_length"].item() for item in batch) - for item in batch: - audio = torch.nn.functional.pad(item["audio"], (0, max_audio_length - item["audio"].size(0)), value=0) - audios_padded.append(audio) - audio_lengths.append(item["audio_length"]) - file_paths.append(item["file_path"]) - codec_file_paths.append(item["codec_file_path"]) - - return { - "audios": torch.stack(audios_padded), - "audio_lengths": torch.stack(audio_lengths), - "audio_file_paths": file_paths, - "codec_file_paths": codec_file_paths, - } - - -class 
CodecExtractor(pl.LightningModule): - def __init__(self, model_path): - super().__init__() - self.codec_model = AudioCodecModel.restore_from(model_path, strict=False) - self.codec_model.eval() - - def forward(self, batch): - with torch.cuda.amp.autocast(enabled=False): - with torch.no_grad(): - codes, codes_lengths = self.codec_model.encode(audio=batch["audios"], audio_len=batch["audio_lengths"]) - return codes, codes_lengths - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - codes, codes_lengths = self(batch) - for i, file_path in enumerate(batch["codec_file_paths"]): - # get directory from file path - item_codes = codes[i, :, : codes_lengths[i]] # 8, T - torch.save(item_codes.cpu().type(torch.int16), file_path) - return None - - -def read_manifest(manifest_path): - records = [] - with open(manifest_path, 'r') as f: - all_lines = f.readlines() - for line in all_lines: - line = line.strip() - record = json.loads(line) - records.append(record) - return records - - -def write_manifest(manifest_path, records): - with open(manifest_path, 'w') as f: - file_str = "" - for record in records: - file_str += json.dumps(record) + "\n" - file_str = file_str.strip() - f.write(file_str) - print("Wrote {} records to: {}".format(len(records), manifest_path)) - - -@rank_zero_only -def update_manifests(manifests, save_dir, dataset_names, codec_model_name): - for midx, manifest in enumerate(manifests): - records = read_manifest(manifest) - for ridx, record in enumerate(records): - audio_codes_path = record["audio_filepath"].replace(".wav", ".pt").replace(".flac", ".pt") - audio_codes_path = os.path.join(save_dir, dataset_names[midx], audio_codes_path) - record["target_audio_codes_path"] = audio_codes_path - if ridx % 10 == 0: - assert os.path.exists(audio_codes_path), "Audio codes not found: {}".format(audio_codes_path) - - if "context_audio_filepath" in record: - context_audio_codes_path = ( - record["context_audio_filepath"].replace(".wav", ".pt").replace(".flac", ".pt") - ) - context_audio_codes_path = os.path.join(save_dir, dataset_names[midx], context_audio_codes_path) - record["context_audio_codes_path"] = context_audio_codes_path - if ridx % 10 == 0: - assert os.path.exists(context_audio_codes_path), "Context audio codes not found: {}".format( - context_audio_codes_path - ) - - write_manifest(manifest.replace(".json", "_withAudioCodes_{}.json".format(codec_model_name)), records) - - -def prepare_directories(base_save_dir, codec_model_name, manifests, audio_base_dirs, dataset_names): - print("In prepare_directories") - save_dir = os.path.join(base_save_dir, codec_model_name) - file_lists = [] - for midx, manifest in enumerate(manifests): - records = read_manifest(manifest) - unique_audio_file_paths = {} - for record in records: - unique_audio_file_paths[record["audio_filepath"]] = 1 - if "context_audio_filepath" in record: - unique_audio_file_paths[record["context_audio_filepath"]] = 1 - file_list = list(unique_audio_file_paths.keys()) - file_lists.append(file_list) - for file_path in file_list: - dir_path = os.path.dirname(file_path) - out_dir_path = os.path.join(save_dir, dataset_names[midx], dir_path) - if not os.path.exists(out_dir_path): - os.makedirs(out_dir_path, exist_ok=True) - print("Created directories for saving audio codes at: ", save_dir, len(file_lists)) - return save_dir, file_lists - - -if __name__ == "__main__": - """ - Usage: - python scripts/magpietts/codec_extraction.py \ - --manifests 
/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json \ - --audio_base_dirs /datap/misc/Datasets/VCTK-Corpus \ - --codec_model_name codec21Khz_no_eliz \ - --dataset_names smallvctk \ - --save_dir /home/pneekhara/2023/SimpleT5NeMo/codec_outputs_21Khz \ - --codec_model_path /datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo \ - --sample_rate 22050 \ - --pad_multiple 1024 \ - --devices -1 \ - --num_nodes 1 - """ - parser = argparse.ArgumentParser() - parser.add_argument("--manifests", type=str) - parser.add_argument("--audio_base_dirs", type=str) - parser.add_argument("--dataset_names", type=str) - parser.add_argument("--save_dir", type=str) - parser.add_argument("--codec_model_path", type=str) - parser.add_argument("--codec_model_name", type=str) - parser.add_argument("--sample_rate", type=int) - parser.add_argument("--pad_multiple", type=int) - parser.add_argument("--devices", type=int, default=-1) - parser.add_argument("--num_nodes", type=int, default=1) - parser.add_argument("--batch_size", type=int, default=16) - parser.add_argument("--num_workers", type=int, default=4) - args = parser.parse_args() - - trainer = Trainer( - devices=args.devices, - accelerator="gpu", - strategy=DDPStrategy(find_unused_parameters=False), - num_nodes=args.num_nodes, - log_every_n_steps=1, - max_epochs=1, - logger=False, - ) - - audio_base_dirs = args.audio_base_dirs.split(",") - dataset_names = args.dataset_names.split(",") - manifests = args.manifests.split(",") - - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - if rank == 0: - save_dir, file_lists = prepare_directories( - args.save_dir, args.codec_model_name, manifests, audio_base_dirs, dataset_names - ) - results = [save_dir, file_lists] - else: - results = [None, None] - torch.distributed.broadcast_object(results, src=0) - save_dir, file_lists = results - else: - save_dir, file_lists = prepare_directories( - args.save_dir, args.codec_model_name, manifests, audio_base_dirs, dataset_names - ) - - codec_extractor = CodecExtractor(args.codec_model_path) - - # Dataset and DataLoader - dataset = AudioDataset( - file_lists=file_lists, - base_audio_dirs=audio_base_dirs, - dataset_names=dataset_names, - out_dir=save_dir, - sample_rate=args.sample_rate, - pad_multiple=args.pad_multiple, - ) - - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - shuffle=False, - collate_fn=dataset.collate_fn, - ) - - # Run prediction (Saves the audio codes to files) - trainer.predict(codec_extractor, dataloaders=dataloader) - - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - update_manifests(manifests, save_dir, dataset_names, args.codec_model_name) diff --git a/scripts/magpietts/create_text_context_lhotse_manifest.py b/scripts/magpietts/create_text_context_lhotse_manifest.py deleted file mode 100644 index b229d2bfd2b7..000000000000 --- a/scripts/magpietts/create_text_context_lhotse_manifest.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Create Text Context Lhotse Manifest Script - -This script converts MagpieTTS Lhotse manifest files from audio-based context to text-based context -for model training. It processes sharded datasets by extracting speaker and suffix information from -supervision IDs and replacing complex audio context metadata with simplified text context strings. - -The script supports three datasets: -- rivaLindyRodney -- rivaEmmaMeganSeanTom -- jhsdGtc20Amp20Keynote - -For each dataset, it: -1. Finds and validates all shard files in the input directory -2. Processes each shard by replacing audio context with text context -3. Saves the modified shards to a new directory with "_textContext" suffix -4. Validates the output by inspecting the last processed cut - -The script expects the input data to be organized as: -``` -./model_release_2505/lhotse_shar/{dataset}/lhotse_shar_shuffle_shardSize256/cuts/ -├── cuts.000000.jsonl.gz -├── cuts.000001.jsonl.gz -└── cuts.000002.jsonl.gz -``` - -Output will be saved to: -``` -./model_release_2505/lhotse_shar/{dataset}/lhotse_shar_shuffle_shardSize256/cuts_textContext/ -├── cuts.000000.jsonl.gz -├── cuts.000001.jsonl.gz -└── cuts.000002.jsonl.gz -``` - -Usage: - python create_text_context_lhotse_manifest.py - - -**BEFORE (Original Audio Context):** -The input cut contains audio context information with references to other audio files: -``` -custom={ - 'emotion': 'happy', - 'context_speaker_similarity': 0.8681941628456116, - 'context_audio_offset': 0.0, - 'context_audio_duration': 4.17, - 'context_audio_text': 'The river bared its bosom, and snorting steamboats challenged the wilderness.', - 'context_recording_id': 'rec-Rodney-44khz-CMU_HAPPY-RODNEY_CMU_HAPPY_000487' -} -``` - -**AFTER (Text Context):** -The output cut contains simplified text context information: -``` -custom={ - 'context_text': 'Speaker and Emotion: | Language:en Dataset:rivaLindyRodney Speaker:Rodney_CMU_HAPPY |', - 'emotion': 'happy' # preserved from original -} -``` -""" - -import glob -import logging -import os -import re -from functools import partial - -from lhotse import CutSet -from rich import print -from tqdm import tqdm - - -def batch_replace_and_write(cut_filepath, new_cut_filepath, dataset_name): - """ - Process a single Lhotse shard file by replacing audio context with text context. - - This function loads a CutSet from a shard file, applies the text context transformation - to each cut, and saves the modified CutSet to a new file. 
- - Args: - cut_filepath (str): Path to the input shard file (e.g., cuts.000000.jsonl.gz) - new_cut_filepath (str): Path where the modified shard file will be saved - dataset_name (str): Name of the dataset being processed, used to determine - how to parse supervision IDs for speaker information - """ - print(f" Processing {dataset_name}: {cut_filepath} --> {new_cut_filepath}") - cuts = CutSet.from_file(cut_filepath) - cuts_with_validation = cuts.map(partial(replace_audio_context_with_text_context, dataset_name=dataset_name)) - cuts_with_validation.to_file(new_cut_filepath) - - -def replace_audio_context_with_text_context(cut, dataset_name): - """ - Replace audio context information with text context for a single cut. - - This function extracts speaker and speaker suffix information from the supervision ID - and creates a text-based context string. The parsing logic varies by dataset - due to different ID formats. - - Args: - cut: A Lhotse Cut object containing audio and supervision information - dataset_name (str): Name of the dataset, determines parsing logic: - - "rivaLindyRodney": Uses items[4] as speaker suffix - - "rivaEmmaMeganSeanTom": Extracts middle parts of items[4] split by "_" - - "jhsdGtc20Amp20Keynote": Uses items[3] as speaker suffix - - Returns: - cut: The modified Cut object with updated custom context information - - Raises: - ValueError: If dataset_name is not one of the supported datasets - - Example: - For a cut with speaker "Rodney" and supervision ID "sup-rec-Rodney-44khz-CMU_HAPPY-RODNEY_CMU_HAPPY_000452", - this might create context_text: "Speaker and Emotion: | Language:en Dataset:rivaLindyRodney Speaker:Rodney_CMU_HAPPY |" - """ - speaker = cut.supervisions[0].speaker - seg_id = cut.supervisions[0].id - items = seg_id.split("-") - - if dataset_name == "rivaLindyRodney": - speaker_suffix = items[4] - elif dataset_name == "rivaEmmaMeganSeanTom": - speaker_suffix = "_".join(items[4].split("_")[1:-1]) - elif dataset_name == "jhsdGtc20Amp20Keynote": - speaker_suffix = items[3] - else: - raise ValueError(f"Unknown dataset name: {dataset_name}") - - text_context = f"Speaker and Emotion: {speaker.rstrip('| ')}_{speaker_suffix} |" - new_custom = {"context_text": text_context} - - # keep original emotion state if any. - if cut.supervisions[0].has_custom("emotion"): - new_custom.update({"emotion": cut.supervisions[0].emotion}) - - cut.supervisions[0].custom = new_custom - - return cut - - -def find_and_verify_shards(cuts_dir: str): - """ - Find and validate all Lhotse shard files in the specified directory. - - This function searches for shard files matching the pattern "cuts.*.jsonl.gz" - and verifies that the shard indices are contiguous starting from 0. This ensures - that all shards are present and properly numbered for processing. 
- - Args: - cuts_dir (str): Directory path containing the shard files - - Returns: - list[str]: Sorted list of paths to all shard files - - Raises: - FileNotFoundError: If no shard files are found matching the expected pattern - ValueError: If shard indices are not contiguous or don't start from 0 - - Example: - If cuts_dir contains files: cuts.000000.jsonl.gz, cuts.000001.jsonl.gz, cuts.000002.jsonl.gz - Returns: ['/path/to/cuts.000000.jsonl.gz', '/path/to/cuts.000001.jsonl.gz', '/path/to/cuts.000002.jsonl.gz'] - """ - cuts_shard_pattern = os.path.join(cuts_dir, "cuts.*.jsonl.gz") - all_cuts_shard_paths = sorted(glob.glob(cuts_shard_pattern)) - - if not all_cuts_shard_paths: - msg = f"No input cut shards found matching pattern: {cuts_shard_pattern}. Cannot proceed." - logging.error(msg) - raise FileNotFoundError(msg) - - num_total_shards = len(all_cuts_shard_paths) - - # Verify shard indices are contiguous and start from 0 based on filenames (globally) - first_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[0]).group(1) - last_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[-1]).group(1) - first_idx = int(first_idx_str) - last_idx = int(last_idx_str) - expected_last_idx = num_total_shards - 1 - if first_idx != 0: - raise ValueError(f"Expected first shard index to be 0, but found {first_idx} in {all_cuts_shard_paths[0]}") - if last_idx != expected_last_idx: - raise ValueError( - f"Expected last shard index to be {expected_last_idx}, but found {last_idx} in {all_cuts_shard_paths[-1]}" - ) - logging.info( - f"Verified {num_total_shards} total shard files globally, with indices from {first_idx} to {last_idx}." - ) - return all_cuts_shard_paths - - -if __name__ == "__main__": - datasets = ["rivaLindyRodney", "rivaEmmaMeganSeanTom", "jhsdGtc20Amp20Keynote"] - for dataset in datasets: - cut_dir = f"./model_release_2505/lhotse_shar/{dataset}/lhotse_shar_shuffle_shardSize256/cuts" - all_cuts_shard_paths = find_and_verify_shards(cut_dir) - cut_dir_tc = cut_dir + "_textContext" - os.makedirs(cut_dir_tc, exist_ok=True) - - for cut_filepath in tqdm(all_cuts_shard_paths, total=len(all_cuts_shard_paths)): - cut_basename = os.path.basename(cut_filepath) - cut_filepath_tc = os.path.join(cut_dir_tc, cut_basename) - batch_replace_and_write(cut_filepath, cut_filepath_tc, dataset_name=dataset) - - # validate - cuts = CutSet.from_file(cut_filepath_tc) - cuts_list = list() - for cut in cuts: - cuts_list.append(cut) - print(cuts_list[-1]) diff --git a/scripts/magpietts/eval_squimmos.py b/scripts/magpietts/eval_squimmos.py deleted file mode 100644 index af3a02cf381d..000000000000 --- a/scripts/magpietts/eval_squimmos.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License.from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE -import argparse -import os - -import librosa -import numpy as np -import scipy.stats as stats -import torch - - -def find_sample_audios(audio_dir): - file_list = [] - for f in os.listdir(audio_dir): - if "predicted_audio" in f and f.endswith(".wav"): - audio_number = int(f.split("_")[-1].split(".wav")[0]) - file_list.append((audio_number, os.path.join(audio_dir, f))) - file_list.sort() - file_list = [t[1] for t in file_list] - return file_list - - -def compute_mean_and_confidence_interval(measurements, confidence=0.95): - mean = np.mean(measurements) - std_err = stats.sem(measurements) - - confidence_interval = std_err * stats.t.ppf((1 + confidence) / 2, len(measurements) - 1) - - return "{:.4f} +/- {:.4f}".format(mean, confidence_interval), mean, confidence_interval - - -def main(): - parser = argparse.ArgumentParser(description='Evaluate Squim MOS') - parser.add_argument('--exp_base_dir', type=str, default="/datap/misc/ContinuousEvalResults/NewTransformerKoelTTS") - parser.add_argument( - '--audio_dirs', - type=str, - default="svencoder_small_sp_ks3_onlyphoneme_epoch242_Temp0.6_Topk80_Cfg_False_1.0_libri_val", - ) - args = parser.parse_args() - - device = "cuda" if torch.cuda.is_available() else "cpu" - squim_mos_model = SQUIM_SUBJECTIVE.get_model().to(device) - - squim_score_list = [] - if args.audio_dirs == "all": - audio_dirs = [d for d in os.listdir(args.exp_base_dir) if os.path.isdir(os.path.join(args.exp_base_dir, d))] - else: - audio_dirs = args.audio_dirs.split(",") - out_file = os.path.join(args.exp_base_dir, "squim_mos_score.csv") - for audio_dir in audio_dirs: - print("Evaluating audio dir: ", audio_dir) - audio_dir_path = os.path.join(args.exp_base_dir, audio_dir, "audio") - audio_files = find_sample_audios(audio_dir_path) - for audio_file in audio_files: - pred_wav, sr = librosa.load(audio_file, sr=16000) - pred_wav = torch.tensor(pred_wav).to(device).unsqueeze(0) - - gt_path = audio_file.replace("predicted_audio", "target_audio") - gt_wav, sr = librosa.load(gt_path, sr=16000) - gt_wav = torch.tensor(gt_wav).to(device).unsqueeze(0) - with torch.no_grad(): - squm_mos_score = squim_mos_model(pred_wav, gt_wav) - squim_score_list.append(squm_mos_score.item()) - - mean_with_ci, mean, confidence_interval = compute_mean_and_confidence_interval(squim_score_list) - # Add to audio_dir,mean_with_ci to csv - with open(out_file, "a") as f: - f.write(audio_dir + "," + mean_with_ci + "\n") - print("Audio dir: ", audio_dir, "Mean with CI: ", mean_with_ci) - print("Wrote to file: ", out_file) - - -if __name__ == "__main__": - main() diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py index 4e2f58575010..2380e8372ca9 100644 --- a/scripts/magpietts/evalset_config.py +++ b/scripts/magpietts/evalset_config.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+""" +Used as a datafile for infer_and_evaluate.py +""" dataset_meta_info = { 'riva_hard_digits': { 'manifest_path': '/Data/evaluation_manifests/hard-digits-path-corrected.ndjson', diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py index 20d93dcad3a4..a1b25705741b 100644 --- a/scripts/magpietts/evaluate_generated_audio.py +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Used in infer_and_evaluate.py to obtain metrics such as ASR_WER and UTMOSV2 scores. +""" import argparse import json import logging diff --git a/scripts/magpietts/extend_nemo_manifest_with_context_audio.py b/scripts/magpietts/extend_nemo_manifest_with_context_audio.py deleted file mode 100644 index fb678f02f8d0..000000000000 --- a/scripts/magpietts/extend_nemo_manifest_with_context_audio.py +++ /dev/null @@ -1,930 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -import logging -import os -import random -import re -from collections import defaultdict -from pathlib import Path - -import lightning.pytorch as pl -import torch -import wandb -from lhotse.dataset.collation import collate_vectors -from lightning.pytorch import Trainer -from lightning.pytorch.loggers import WandbLogger -from lightning.pytorch.strategies import DDPStrategy -from torch.utils.data import DataLoader, Dataset - -import nemo.collections.asr as nemo_asr -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment - -logger = logging.getLogger(__name__) - -# Constant for masking identical items in similarity matrix -# Set below valid cosine similarity range [-1, 1] to ensure masked items are never selected -MASKED_SIMILARITY_VALUE = -2.0 - -""" -Usage: -python scripts/magpietts/extend_manifest_with_context_audio.py - --manifest /path/to/input.json - --audio-base-dir /path/to/audio - --output-dir /path/to/output_sharded_manifests - --batch-size 16 - --devices 2 - --num-nodes 1 - --flush-threshold-items 20000 - --num-workers 4 - --context-min-duration 3.0 - --context-min-ssim 0.6 - --max-speaker-items 20000 # optional, prevents OOM for large speakers - -This script distributes speakers across DDP ranks. Each rank processes its assigned speakers -and writes a partial manifest. Rank 0 then merges these into a final manifest. - -The --max-speaker-items parameter limits the size of the context pool per speaker to prevent OOM -when computing similarity matrices. If a speaker has more items than this limit, a random -sample will be used as the context pool, but all items will still be processed to find -their best context from this pool. 
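The per-item selection this describes (most-similar *other* segment from the context pool, subject to the --context-min-ssim and --context-min-duration thresholds) can be sketched independently of the DDP machinery. The helper below is illustrative only; its name and arguments are not part of this script:

```python
import torch

MASKED_SIMILARITY_VALUE = -2.0  # sentinel below the valid cosine range [-1, 1]

def pick_contexts(item_embs, pool_embs, pool_durations, self_positions,
                  min_ssim=0.6, min_duration=3.0):
    """Return, per item, (context pool index, similarity) or None if nothing qualifies."""
    a = torch.nn.functional.normalize(item_embs, p=2, dim=1)
    b = torch.nn.functional.normalize(pool_embs, p=2, dim=1)
    sim = a @ b.T                               # (N, M) cosine-similarity matrix
    for n, m in self_positions.items():         # an item must never be its own context
        sim[n, m] = MASKED_SIMILARITY_VALUE
    sims, order = torch.sort(sim, dim=1, descending=True)
    choices = []
    for i in range(sim.size(0)):
        choice = None
        for r in range(order.size(1)):
            s = sims[i, r].item()
            if s <= MASKED_SIMILARITY_VALUE or s < min_ssim:
                break                           # sorted descending: nothing better follows
            if pool_durations[int(order[i, r])] >= min_duration:
                choice = (int(order[i, r]), s)
                break
        choices.append(choice)                  # None -> the item is discarded
    return choices
```

With float32 embeddings the (N, M) matrix costs N*M*4 bytes, e.g. roughly 16 GB for 200,000 items against a 20,000-item pool, versus roughly 160 GB for the unbounded N x N case.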
- -Input manifest example entry: -{ - "audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", - "text": "the face.", - "speaker": "| Language:en Dataset:NVYT_2505 Speaker:_8Kirz57BTY_SPEAKER_01 |", - "offset": 2.8, - "duration": 0.48, - "bandwidth": 10125, - "stoi_squim": 0.98, - "sisdr_squim": 18.235, - "pesq_squim": 2.349, - "dataset_id": "369a9f1a-65eb-4c09-8de3-8babea29da4c", - "dataset_version": "2024_11_02_092919", - "dataset_name": "yt_mixed", - "normalized_text": "the face." -} - -Output manifest example entry: -{ - "audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", - "text": "the face.", - "speaker": "| Language:en Dataset:NVYT_2505 Speaker:_8Kirz57BTY_SPEAKER_01 |", - "offset": 2.8, - "duration": 0.48, - "bandwidth": 10125, - "stoi_squim": 0.98, - "sisdr_squim": 18.235, - "pesq_squim": 2.349, - "dataset_id": "369a9f1a-65eb-4c09-8de3-8babea29da4c", - "dataset_version": "2024_11_02_092919", - "dataset_name": "yt_mixed", - "normalized_text": "the face.", - "context_audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav", - "context_audio_offset": 5.6, - "context_audio_duration": 6.0, - "context_audio_text": "would you mind..", - "context_audio_normalized_text": "would you mind..", - "context_audio_speaker_similarity": 0.85 -} -""" - - -def check_speaker_format(item: str): - """Enforce speaker format like '| Language:en Dataset:HiFiTTS Speaker:9136_other |'""" - pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" - if not isinstance(item, str): - return False - return bool(re.match(pattern, item)) - - -class SpeakerShardedAudioDataset(Dataset): - def __init__(self, assigned_records_list, base_audio_dir, sample_rate=16000): - self.sample_rate = sample_rate - self.base_audio_dir = base_audio_dir - self.processed_records = assigned_records_list - - def __len__(self): - return len(self.processed_records) - - def get_wav_from_filepath(self, audio_filepath_rel, offset_in_sec=0, duration_in_sec=None): - full_audio_filepath = os.path.join(self.base_audio_dir, audio_filepath_rel) - try: - features = AudioSegment.from_file( - audio_file=full_audio_filepath, - target_sr=self.sample_rate, - int_values=False, # TODO: if input is FLAC, then we should set this to True. 
- offset=offset_in_sec, - duration=duration_in_sec, - ) - except Exception as e: - logger.warning( - f"[Skipping Wav Load] Failed for `{full_audio_filepath}` (relative: `{audio_filepath_rel}`, offset={offset_in_sec}, duration={duration_in_sec}): {e}" - ) - return None, None - audio_samples = features.samples - return torch.tensor(audio_samples), torch.tensor(len(audio_samples)).long() - - def __getitem__(self, idx): - item_info = self.processed_records[idx] - - audio, audio_length = self.get_wav_from_filepath( - item_info["audio_filepath"], item_info["offset"], item_info["duration"] - ) - if audio is None or audio_length is None: - return None - - output_item = item_info.copy() - output_item.update( - { - "audio": audio, - "audio_length": audio_length, - } - ) - return output_item - - def collate_fn(self, batch): - valid_items = [item for item in batch if item is not None] - if not valid_items: - return { - "audios": torch.empty(0), - "audio_lengths": torch.empty(0), - "metadata_list": [], - "parsed_speaker_ids_list": [], - } - - audio_padded = collate_vectors([item["audio"] for item in valid_items], padding_value=0.0) - audio_lengths = torch.tensor([item["audio_length"] for item in valid_items]) - metadata_list = [ - {k: v for k, v in item.items() if k not in ['audio', 'audio_length', 'parsed_speaker_id']} - for item in valid_items - ] - parsed_speaker_ids_for_batch = [item['parsed_speaker_id'] for item in valid_items] - - return { - "audios": audio_padded, - "audio_lengths": audio_lengths, - "metadata_list": metadata_list, - "parsed_speaker_ids_list": parsed_speaker_ids_for_batch, - } - - -class EmbeddingSimilarityExtractorSharded(pl.LightningModule): - def __init__( - self, - output_dir: str, - output_file_prefix: str, - flush_threshold_items: int, - context_min_duration: float, - context_min_ssim: float, - speaker_expected_counts_map: dict, - initial_assigned_count: int, - max_speaker_items: int = None, - ): - super().__init__() - self.sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - 'titanet_large', map_location=torch.device('cpu') - ) - self.sv_model.eval() - - self.output_dir = Path(output_dir) - self.output_file_prefix = output_file_prefix - self.flush_threshold_items = flush_threshold_items - self.context_min_duration = context_min_duration - self.context_min_ssim = context_min_ssim - self.speaker_expected_counts = speaker_expected_counts_map - self.initial_assigned_count = initial_assigned_count - self.max_speaker_items = max_speaker_items - - # Per-rank attributes - self.output_file_path = None - self.speaker_data_accumulator = defaultdict(list) - self.total_accumulated_items = 0 # total number of items accumulated across all speakers for this rank - self.processed_speakers_set = set() # set of speakers that have been processed and flushed - self.ready_to_flush_speaker_ids = set() # set of speakers that have accumulated enough items to be flushed - self.output_manifest_file = None - # total num of items discarded due to no suitable context for this rank - self.total_discarded_no_suitable_context_this_rank = 0 - self.total_items_written_this_rank = 0 # total items written to manifest by this rank - - def setup(self, stage: str): - if stage == "predict": - self.sv_model.to(self.device) - self.output_file_path = self.output_dir / f"{self.output_file_prefix}_rank{self.global_rank}.json" - self.output_dir.mkdir(parents=True, exist_ok=True) - self.output_manifest_file = open(self.output_file_path, "w", encoding="utf-8") - logger.info(f"Writing partial manifest to: 
`{self.output_file_path}`") - if self.max_speaker_items: - logger.info(f"Max speaker items limit set to: {self.max_speaker_items}") - else: - logger.info("No max speaker items limit set (potential OOM risk for very large speakers)") - logger.debug(f"Expected speaker counts for model: {self.speaker_expected_counts}") - - def forward(self, batch): - with torch.no_grad(): - _, speaker_embeddings = self.sv_model.forward( - input_signal=batch['audios'], - input_signal_length=batch['audio_lengths'], - ) - return speaker_embeddings - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - if batch['audios'].nelement() == 0: - return [] - - speaker_embeddings_gpu = self(batch) - - output_items_for_batch_end = [] - for i, single_metadata_item in enumerate(batch["metadata_list"]): - embedding_cpu_fp32 = speaker_embeddings_gpu[i].cpu().type(torch.float32) - base_speaker_id_for_item = batch["parsed_speaker_ids_list"][i] - - processed_item = { - "speaker_id_for_grouping": base_speaker_id_for_item, - "embedding": embedding_cpu_fp32, - "metadata": single_metadata_item, - } - output_items_for_batch_end.append(processed_item) - - return output_items_for_batch_end - - def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): - for item in outputs: - base_speaker_id = item['speaker_id_for_grouping'] - - if base_speaker_id not in self.processed_speakers_set: - self.speaker_data_accumulator[base_speaker_id].append( - {'embedding': item['embedding'], 'metadata': item['metadata']} - ) - self.total_accumulated_items += 1 - - expected_count = self.speaker_expected_counts[base_speaker_id] - current_count = len(self.speaker_data_accumulator[base_speaker_id]) - - if current_count == expected_count: - self.ready_to_flush_speaker_ids.add(base_speaker_id) - logger.debug( - f"Speaker {base_speaker_id} is complete with {current_count} items. Added to `ready_to_flush_speaker_ids`." - ) - elif current_count > expected_count: - msg = f"Speaker {base_speaker_id} has {current_count} items, but expected {expected_count}. Possible data inconsistency or error in expected counts." - logger.error(msg) - raise ValueError(msg) - else: - msg = f"Received new item for already processed speaker '{base_speaker_id}'. This may indicate issues with data sharding, expected counts, or duplicate data." - logger.error(msg) - raise ValueError(msg) - - if self.total_accumulated_items >= self.flush_threshold_items and self.ready_to_flush_speaker_ids: - self._process_and_flush_speakers_local() - - def _process_and_flush_speakers_local(self): - speakers_to_process_now = list(self.ready_to_flush_speaker_ids) - self.ready_to_flush_speaker_ids.clear() - - if not speakers_to_process_now: - msg = "_process_and_flush_speakers_local called, but `speakers_to_process_now` is empty after list conversion. This is unexpected." - logger.error(msg) - raise ValueError(msg) - - logger.info( - f"Flushing {len(speakers_to_process_now)} completed speakers. 
" - f"Current total accumulated items: {self.total_accumulated_items}" - ) - - for speaker_id in speakers_to_process_now: - speaker_items = self.speaker_data_accumulator.pop(speaker_id) - self.total_accumulated_items -= len(speaker_items) - self.processed_speakers_set.add(speaker_id) - - # Apply speaker size limit to prevent OOM while processing all items - all_items_to_process = speaker_items # We want to process ALL items - - # Create context pool with original indices for easy identification - if self.max_speaker_items and len(speaker_items) > self.max_speaker_items: - logger.warning( - f"Speaker {speaker_id} has {len(speaker_items)} items, exceeding max limit of {self.max_speaker_items}. " - f"Using random sample of {self.max_speaker_items} items as context pool, but processing all {len(speaker_items)} items." - ) - # Randomly sample with original indices preserved - random.seed(12345) # For reproducibility - indexed_items = [(idx, item) for idx, item in enumerate(speaker_items)] - sampled_indexed_items = random.sample(indexed_items, self.max_speaker_items) - context_pool_items = [item for _, item in sampled_indexed_items] - context_pool_original_indices = [idx for idx, _ in sampled_indexed_items] - else: - # Use all items as context pool - context_pool_items = speaker_items - context_pool_original_indices = list(range(len(speaker_items))) - - # NOTE: Now we compute similarities between ALL items and the context pool. - # This limits the similarity matrix to N×M instead of N×N where M <= max_speaker_items. - # Memory usage: N×M×4 bytes instead of N×N×4 bytes. - all_embeddings = torch.stack([item['embedding'] for item in all_items_to_process]) - context_embeddings = torch.stack([item['embedding'] for item in context_pool_items]) - - all_embeddings_norm = torch.nn.functional.normalize(all_embeddings, p=2, dim=1) - context_embeddings_norm = torch.nn.functional.normalize(context_embeddings, p=2, dim=1) - - # Compute N×M similarity matrix: each row is similarities for one item against all context candidates - similarity_matrix = torch.matmul(all_embeddings_norm, context_embeddings_norm.transpose(0, 1)) - - # Mask positions where items are identical (same item appearing in both N and M sets) - # Using original indices as identifiers. This prevents an item from being selected as its own context. 
- # Create a mapping from original indices to context pool positions - original_index_to_context_position = {} - for context_pos, original_idx in enumerate(context_pool_original_indices): - original_index_to_context_position[original_idx] = context_pos - - # Mask similarities for identical items - for n_idx in range(len(all_items_to_process)): - if n_idx in original_index_to_context_position: - context_pos = original_index_to_context_position[n_idx] - similarity_matrix[n_idx, context_pos] = MASKED_SIMILARITY_VALUE - - # Sort all similarities for each item to iterate through candidates - # sorted_similarities_tensor will contain sorted similarities for each row (original item) - # sorted_indices_tensor will contain indices in the context_pool - sorted_similarities_tensor, sorted_indices_tensor = torch.sort(similarity_matrix, dim=1, descending=True) - - num_records_written_for_speaker = 0 - # Initialize a counter for items discarded for this specific speaker - num_discarded_for_this_speaker_no_context = 0 - - for i, current_item_data in enumerate(all_items_to_process): - output_record = current_item_data['metadata'].copy() - write_this_record = False - - # Iterate through potential candidates from context pool, sorted by similarity - for candidate_rank in range(sorted_indices_tensor.size(1)): - candidate_ssim = sorted_similarities_tensor[i, candidate_rank].item() - context_pool_idx = sorted_indices_tensor[i, candidate_rank].item() - - # if ANY candidate has similarity ≤ MASKED_SIMILARITY_VALUE, all subsequent ones will be ≤ MASKED_SIMILARITY_VALUE - # since similarities are sorted in descending order, we can break early - if candidate_ssim <= MASKED_SIMILARITY_VALUE: - break - - # If SSIM is below threshold, stop searching for this item - if candidate_ssim < self.context_min_ssim: - break - - # Check duration if SSIM is acceptable - best_meta_dict = context_pool_items[context_pool_idx]['metadata'] - candidate_duration = best_meta_dict["duration"] - - if candidate_duration >= self.context_min_duration: - # Found a suitable candidate, update record and stop searching for this item - record_update_dict = { - "context_speaker_similarity": candidate_ssim, - "context_audio_filepath": best_meta_dict["audio_filepath"], - "context_audio_offset": best_meta_dict["offset"], - "context_audio_duration": candidate_duration, - "context_audio_text": best_meta_dict["text"], - } - normalized_text_candidate = best_meta_dict.get("normalized_text", None) - if normalized_text_candidate is not None: - record_update_dict["context_audio_normalized_text"] = normalized_text_candidate - - output_record.update(record_update_dict) - write_this_record = True - break - - if write_this_record: - self.output_manifest_file.write(json.dumps(output_record, ensure_ascii=False) + "\n") - num_records_written_for_speaker += 1 - else: - # This item will be discarded as no suitable context was found - num_discarded_for_this_speaker_no_context += 1 - - # Accumulate to rank-level total - self.total_discarded_no_suitable_context_this_rank += num_discarded_for_this_speaker_no_context - self.total_items_written_this_rank += num_records_written_for_speaker - - if len(speakers_to_process_now) > 0: - self.output_manifest_file.flush() # ensure all data currently held in the buffer is immediately written to disk. - logger.info(f"Flushing of completed speakers done. Local remaining items: {self.total_accumulated_items}") - - def on_predict_epoch_end(self): - logger.info( - f"Epoch end: Identifying remaining speakers to flush. 
" - f"Speakers in accumulator: {len(self.speaker_data_accumulator)}, Already processed: {len(self.processed_speakers_set)}" - ) - - for speaker_id, items in list(self.speaker_data_accumulator.items()): - if speaker_id not in self.processed_speakers_set: - expected_count = self.speaker_expected_counts[speaker_id] - actual_count = len(items) - if actual_count == expected_count: - logger.info( - f"Epoch end: Speaker {speaker_id} is complete ({actual_count}/{expected_count}). Adding to ready set." - ) - self.ready_to_flush_speaker_ids.add(speaker_id) - else: - msg = f"Epoch end: Speaker {speaker_id} is still in accumulator with {actual_count} items, but expected {expected_count}. This indicates an issue, e.g., not all data for this speaker was received or processed during the epoch." - logger.error(msg) - raise ValueError(msg) - - if self.ready_to_flush_speaker_ids: - logger.info( - f"Epoch end: Calling `_process_and_flush_speakers_local` for {len(self.ready_to_flush_speaker_ids)} ready speakers." - ) - self._process_and_flush_speakers_local() - else: - logger.info(f"Epoch end: No remaining speakers identified as ready to flush.") - - if self.speaker_data_accumulator: # Should be empty if all went well - msg = f"Epoch end: {len(self.speaker_data_accumulator)} speakers still in accumulator post-final flush attempt: {list(self.speaker_data_accumulator.keys())}" - logger.error(msg) - raise ValueError(msg) - - logger.info( - f"Total items discarded on this rank due to no suitable context found (failed SSIM or duration): {self.total_discarded_no_suitable_context_this_rank}" - ) - logger.info(f"Total items written to manifest on this rank: {self.total_items_written_this_rank}") - - # Verification step - expected_total_processed = ( - self.total_items_written_this_rank + self.total_discarded_no_suitable_context_this_rank - ) - if self.initial_assigned_count == expected_total_processed: - logger.info( - f"Verification successful: Initial items ({self.initial_assigned_count}) == Written ({self.total_items_written_this_rank}) + Discarded ({self.total_discarded_no_suitable_context_this_rank})" - ) - else: - msg = f"VERIFICATION FAILED: Initial items ({self.initial_assigned_count}) != Written ({self.total_items_written_this_rank}) + Discarded ({self.total_discarded_no_suitable_context_this_rank}) --- Difference: {self.initial_assigned_count - expected_total_processed}" - logger.error(msg) - raise RuntimeError(msg) - - if self.output_manifest_file: - self.output_manifest_file.close() - self.output_manifest_file = None - logger.info(f"Local processing complete. Partial manifest closed.") - - if torch.distributed.is_initialized(): - torch.distributed.barrier() # Wait for all ranks to finish writing files - - -def _parse_speaker_id_libritts(record): - """ - libritts format: audio_filepath = "{subset}/{speaker_id}/{chapter_id}/{speaker_id}_{chapter_id}_{utterance_id}_{segment_id}.wav" - e.g. "train-clean-100/89/218/89_218_000014_000003.wav" - re-organized speaker_id: "{subset}_{speaker_id}_{chapter_id}" - e.g. "train-clean-100_89_218" - """ - parts = record['audio_filepath'].lower().split('/') - return f"{parts[0]}_{parts[1]}_{parts[2]}" - - -def _parse_speaker_id_hifitts(record): - """ - hifitts format: audio_filepath = "{speaker_id}_{audio_quality}/{book_id}/{chapter_name}_{segment_id}.wav" - e.g. "11614_other/12352/prideofjennico_01_castle_0000.flac" - re-organized speaker_id: "{speaker_id}_{audio_quality}_{book_id}_{chapter_name}" - e.g. 
"11614_other_12352_prideofjennico_01_castle" - """ - parts = record['audio_filepath'].lower().split('/') - chapter_name = parts[-1].rsplit('_', 1)[0] - return f"{parts[0]}_{parts[1]}_{chapter_name}" - - -def _parse_speaker_id_hifitts2(record): - """ - hifitts2 format: audio_filepath = "{speaker_id}/{book_id}/{speaker_id}_{book_id}_{chapter_name}_{segment_id}.wav" - e.g. "100/2315/100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies_0.flac" - re-organized speaker_id: "{speaker_id}_{book_id}_{chapter_name}" - e.g. "100_2315_sea_fairies_0812_librivox-01_baum_sea_fairies" - """ - parts = record['audio_filepath'].lower().split('/') - return parts[-1].rsplit('_', 1)[0] - - -def _parse_speaker_id_nvyt2505(record): - """ - nvyt2505 format: audio_filepath = "NVYT_40K_audios_wav/{utterance_id}.wav", which does not contain speaker_id. - e.g. "NVYT_40K_audios_wav/Thg50o7gmsk.wav" - But we can parse the speaker_id from: speaker = "| Language:en Dataset:NVYT_2505 Speaker:Thg50o7gmsk_SPEAKER_00 |". - re-organized speaker_id: "{parsed_speaker_id}" - e.g. "thg50o7gmsk_speaker_00" - """ - speaker_regex = re.compile(r'Speaker:([^ |]+)') - match = speaker_regex.search(record['speaker']) - if not match: - raise ValueError(f"Failed to parse speaker_id from record: {record}") - return match.group(1).lower() - - -def _parse_speaker_id_rivaLindyRodney(record): - """ - rivaLindyRodney format: audio_filepath = "{speaker}/44khz/{emotion}/{speaker}_{emotion}_{utterance_id}.wav" - e.g. "Lindy/44khz/WIZWIKI/LINDY_WIZWIKI_004161.wav" - re-organized speaker_id: "{speaker}_{emotion}" - e.g. "lindy_wizwiki" - """ - parts = record['audio_filepath'].lower().split('/') - return f"{parts[0]}_{parts[2]}" - - -def _parse_speaker_id_rivaEmmaMeganSeanTom(record): - """ - rivaEmmaMeganSeanTom format: audio_filepath = "{speaker}/22_kHz/{speaker}_{emotion}_{utterance_id}.wav" - e.g. "Emma/22_kHz/Emma_Sad_Intense_Correlated_00147.wav" - re-organized speaker_id: "{speaker}_{emotion}" - e.g. "emma_sad_intense_correlated" - """ - parts = record['audio_filepath'].lower().split('/') - return parts[2].rsplit('_', 1)[0] - - -def _parse_speaker_id_jhsdGtc20Amp20Keynote(record): - """ - jhsdGtc20Amp20Keynote format: audio_filepath = "{keynote_event}_KEYNOTE-VOOnly-44khz-16bit-mono_{utterance_id}.wav" - e.g. "AMP20_KEYNOTE-VOOnly-44khz-16bit-mono_12.wav" - re-organized speaker_id: "{keynote_event}" - e.g. "AMP20" - """ - return record['audio_filepath'].lower().rsplit('_', 2)[0] - - -def _get_parsed_speaker_id_for_dataset(dataset_name_arg, record): - """Routes to the appropriate speaker ID parsing function based on dataset_name.""" - if dataset_name_arg == "libritts": - return _parse_speaker_id_libritts(record) - elif dataset_name_arg == "librittsDevClean": - return _parse_speaker_id_libritts(record) - elif dataset_name_arg == "hifitts": - return _parse_speaker_id_hifitts(record) - elif dataset_name_arg == "hifitts2": - return _parse_speaker_id_hifitts2(record) - elif dataset_name_arg == "nvyt2505": - return _parse_speaker_id_nvyt2505(record) - elif dataset_name_arg == "rivaLindyRodney": - return _parse_speaker_id_rivaLindyRodney(record) - elif dataset_name_arg == "rivaEmmaMeganSeanTom": - return _parse_speaker_id_rivaEmmaMeganSeanTom(record) - elif dataset_name_arg == "jhsdGtc20Amp20Keynote": - return _parse_speaker_id_jhsdGtc20Amp20Keynote(record) - else: - logger.error( - f"Unsupported dataset_name '{dataset_name_arg}' provided. Please check the --dataset-name argument." 
- ) - raise ValueError(f"Unsupported dataset_name: {dataset_name_arg}") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--manifest", type=str, required=True) - parser.add_argument("--audio-base-dir", type=str, required=True) - parser.add_argument("--output-dir", type=str, required=True, help="Directory to save rank-specific manifests.") - parser.add_argument( - "--dataset-name", - type=str, - required=True, - choices=[ - "libritts", - "librittsDevClean", - "hifitts", - "hifitts2", - "nvyt2505", - "rivaLindyRodney", - "rivaEmmaMeganSeanTom", - "jhsdGtc20Amp20Keynote", - ], - help="Name of the dataset being processed. This determines the speaker ID parsing logic.", - ) - parser.add_argument("--flush-threshold-items", type=int, default=20000) - parser.add_argument( - "--context-min-duration", type=float, default=3.0, help="Minimum duration for a context audio segment." - ) - parser.add_argument( - "--context-min-ssim", type=float, default=0.6, help="Minimum cosine similarity for a context audio segment." - ) - parser.add_argument( - "--max-speaker-items", - type=int, - default=None, - help="Maximum size of context pool per speaker to prevent OOM. If a speaker has more items, a random sample will be used as context pool, but all items will still be processed. Default: None (no limit, potential OOM risk).", - ) - parser.add_argument("--devices", type=int, default=-1) - parser.add_argument("--num-nodes", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=16) - parser.add_argument("--num-workers", type=int, default=4) - parser.add_argument("--wandb-entity", type=str, default=None) - parser.add_argument("--wandb-project", type=str, default="speaker_similarity_sharded") - parser.add_argument("--wandb-name", type=str, default=None) - parser.add_argument( - "--log-level", - type=str, - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level.", - ) - args = parser.parse_args() - - logging.basicConfig( - level=getattr(logging, args.log_level.upper()), - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - # Initialize DDP early to get rank and world_size for sharding - # PyTorch Lightning Trainer will handle DDP initialization if not done explicitly, - # but we need rank/world_size for data sharding before Trainer setup. - ddp_env_vars_detected = "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ - - user_intended_distributed = False - if isinstance(args.devices, int) and args.devices not in [0, 1]: # 0 for CPU, 1 for single GPU. -1 means all GPUs. - user_intended_distributed = True - if args.num_nodes > 1: - user_intended_distributed = True - - if user_intended_distributed and not ddp_env_vars_detected: - logger.warning( - f"Warning: A distributed run seems intended (num_nodes={args.num_nodes}, devices='{args.devices}'), " - f"but standard DDP environment variables (e.g., `LOCAL_RANK`, `WORLD_SIZE`) were not detected pre-Trainer initialization. " - f"If launching on SLURM, ensure you are using `srun` or have correctly configured your sbatch script. " - f"For local multi-GPU, consider using `torchrun`. " - f"PyTorch Lightning will now attempt to initialize the distributed environment. " - f"If it defaults to a single process, data sharding will be ineffective (all data processed by one rank)." 
- ) - - strategy = ( - DDPStrategy(find_unused_parameters=False) - if (isinstance(args.devices, int) and args.devices != 1 and args.devices != 0) - else "auto" - ) - - trainer = Trainer( - devices=args.devices, - num_nodes=args.num_nodes, - accelerator="gpu", - strategy=strategy, - logger=None, - max_epochs=1, - use_distributed_sampler=False, - ) - - world_size = trainer.world_size - global_rank = trainer.global_rank - - log_format = f"%(asctime)s [RANK {global_rank}] [%(levelname)s] %(message)s" - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - logging.basicConfig(level=getattr(logging, args.log_level.upper()), format=log_format, datefmt="%Y-%m-%d %H:%M:%S") - - if global_rank == 0: - logger.info("Reading and sharding manifest ...") - - temp_sv_model_for_config = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - 'titanet_large', map_location=torch.device('cpu') - ) - # Initialize sample_rate for all ranks; rank 0 will populate it. - # This variable will be broadcast if in distributed mode. - sample_rate = temp_sv_model_for_config.preprocessor._sample_rate - min_duration_in_sec_required = temp_sv_model_for_config.preprocessor.featurizer.hop_length * 2 / sample_rate - del temp_sv_model_for_config - logger.info( - f"Calculated sample_rate: {sample_rate}, min_duration_in_sec_required: {min_duration_in_sec_required:.3f}s" - ) - - speaker_to_records = defaultdict(list) - num_processed_records = 0 - total_initial_records = 0 - - with open(args.manifest, "r", encoding="utf-8") as f: - for line in f: - total_initial_records += 1 - try: - rec = json.loads(line.strip()) - except json.JSONDecodeError: - logger.warning(f"Skipping malformed JSON line: `{line.strip()}`") - continue - - # 1. Apply duration filter - if rec.get("duration") is None or rec.get("duration") < min_duration_in_sec_required: - continue - - # 2. Apply speaker format check - if not check_speaker_format(rec["speaker"]): - msg = f"Invalid speaker format for record: {rec['speaker']}, File: {rec['audio_filepath']}(offset={rec['offset']}, duration={rec['duration']})." - logger.error(msg) - raise ValueError(msg) - - # 3. Parse speaker ID and add to map - rec['parsed_speaker_id'] = _get_parsed_speaker_id_for_dataset(args.dataset_name, rec) - - speaker_to_records[rec['parsed_speaker_id']].append(rec) - num_processed_records += 1 - - num_filtered_out_initial_pass = total_initial_records - num_processed_records - - logger.info( - f"Initial pass filtered out {num_filtered_out_initial_pass} records (e.g., duration). Processing {num_processed_records} records before speaker count filter." - ) - - # Filter speakers with less than 2 segments - speakers_before_count_filter = len(speaker_to_records) - - speakers_with_segment_counts = [ - {"count": len(rec_list), "records": rec_list} - for _, rec_list in speaker_to_records.items() - if len(rec_list) >= 2 - ] - del speaker_to_records - - speakers_after_count_filter = len(speakers_with_segment_counts) - records_after_count_filter = sum(item["count"] for item in speakers_with_segment_counts) - - num_speakers_filtered_by_count = speakers_before_count_filter - speakers_after_count_filter - num_records_filtered_by_speaker_count = num_processed_records - records_after_count_filter - - logger.info( - f"Filtered out {num_speakers_filtered_by_count} speakers (and {num_records_filtered_by_speaker_count} corresponding records) with < 2 segments. " - f"Now processing {records_after_count_filter} records from {speakers_after_count_filter} speakers for sharding." 
- ) - - # Greedy Bin-Packing for speaker distribution - # 1. Sort speakers by segment count in descending order - speakers_with_segment_counts.sort(key=lambda x: x["count"], reverse=True) - - # 2. Initialize rank loads and assignments - rank_loads = [0] * world_size - rank_assignments = [[] for _ in range(world_size)] - - # 3. Assign speakers to ranks using greedy approach - for speaker_info in speakers_with_segment_counts: - # Find the rank with the minimum current load - min_load_rank_idx = rank_loads.index(min(rank_loads)) - - # Assign all records of this speaker to that rank - rank_assignments[min_load_rank_idx].extend(speaker_info["records"]) - # Update the load of that rank - rank_loads[min_load_rank_idx] += speaker_info["count"] - - data_to_distribute = rank_assignments - logger.info( - f"Sharding complete. {sum(len(r) for r in data_to_distribute)} records distributed among {world_size} ranks." - ) - for r, recs in enumerate(data_to_distribute): - logger.info(f"Plan for rank {r} = {len(recs)} records.") - - per_rank_speaker_counts = [] # [{"spk_0": 10, "spk_1": 5}, {"spk_2": 6, "spk_3": 8}, ...] - for rank_idx in range(world_size): - counts_for_rank = defaultdict(int) - for record in data_to_distribute[rank_idx]: - counts_for_rank[record['parsed_speaker_id']] += 1 - per_rank_speaker_counts.append(dict(counts_for_rank)) - - else: # Other ranks prepare to receive - data_to_distribute = [None] * world_size - per_rank_speaker_counts = [None] * world_size - sample_rate = None # Initialize for non-rank 0 before broadcast - - # Broadcast the list of lists of records. Each rank will then pick its part. - if world_size > 1 and not torch.distributed.is_initialized(): - logger.warning( - f"Distributed run (world_size={world_size}) detected, but `torch.distributed` not yet initialized. " - f"Attempting to trigger environment setup via `trainer.strategy.setup_environment()`." - ) - # The trainer's strategy is responsible for setting up the distributed environment. - # This typically happens implicitly during trainer.fit/predict/test/validate calls. - trainer.strategy.setup_environment() - if torch.distributed.is_initialized(): - logger.info( - f"`torch.distributed` successfully initialized after `trainer.strategy.setup_environment()`. Synchronizing ranks." - ) - torch.distributed.barrier() # Ensure all ranks have completed setup before proceeding. - else: - msg = f"[Rank {global_rank}] FATAL: Failed to initialize `torch.distributed` even after calling `trainer.strategy.setup_environment()` for world_size={world_size}. Cannot proceed with distributed data sharding." - logger.error(msg) - raise RuntimeError(msg) - elif world_size == 1 and torch.distributed.is_initialized(): - # This case should ideally not happen (DDP initialized for a single process run by Lightning). - logger.warning(f"Warning: `torch.distributed` is initialized, but world_size is 1. This is unusual.") - elif world_size > 1 and torch.distributed.is_initialized(): - logger.info(f"`torch.distributed` was already initialized. world_size={world_size}. Synchronizing ranks.") - torch.distributed.barrier() - - # Now, proceed with the data distribution logic, expecting `torch.distributed` to be initialized if world_size > 1. 
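The greedy bin-packing used above to balance speakers across ranks reduces to a few lines of pure Python. This restatement is illustrative (the helper name and the example counts are made up):

```python
def greedy_shard(speaker_counts, world_size):
    """Assign each speaker (largest first) to whichever rank currently has the fewest records."""
    loads = [0] * world_size
    assignments = [[] for _ in range(world_size)]
    for speaker, count in sorted(speaker_counts.items(), key=lambda kv: kv[1], reverse=True):
        rank = loads.index(min(loads))
        assignments[rank].append(speaker)
        loads[rank] += count
    return assignments, loads

# greedy_shard({"a": 100, "b": 90, "c": 30, "d": 20}, world_size=2)
# -> ([["a", "d"], ["b", "c"]], [120, 120])
```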
- my_speaker_expected_counts = {} - if torch.distributed.is_initialized(): - torch.distributed.broadcast_object_list(data_to_distribute, src=0) - assigned_records_for_this_rank = data_to_distribute[global_rank] - torch.distributed.broadcast_object_list(per_rank_speaker_counts, src=0) - my_speaker_expected_counts = per_rank_speaker_counts[global_rank] - - # Broadcast sample_rate - if global_rank == 0: - sample_rate_to_broadcast = [sample_rate] - else: - sample_rate_to_broadcast = [None] - torch.distributed.broadcast_object_list(sample_rate_to_broadcast, src=0) - sample_rate = sample_rate_to_broadcast[0] - logger.info(f"Received {len(assigned_records_for_this_rank)} records for processing.") - logger.debug(f"Expected speaker counts for this rank: {my_speaker_expected_counts}") - logger.info(f"Received sample_rate via broadcast: {sample_rate}") - elif world_size == 1: - # data_to_distribute is already prepared by rank 0 code block if world_size was 1 from start - assigned_records_for_this_rank = data_to_distribute[0] if data_to_distribute and data_to_distribute[0] else [] - my_speaker_expected_counts = ( - per_rank_speaker_counts[0] if per_rank_speaker_counts and per_rank_speaker_counts[0] else {} - ) - if not assigned_records_for_this_rank: - msg = f"[Rank {global_rank}] Error: No records were assigned for processing in single process mode. Issue in initial data prep." - logger.error(msg) - raise ValueError(msg) - logger.info(f"Single process, assigned {len(assigned_records_for_this_rank)} records.") - logger.debug(f"Expected speaker counts: {my_speaker_expected_counts}") - logger.info(f"Using sample_rate from rank 0 execution: {sample_rate}") - else: - msg = f"[Rank {global_rank}] Critical: DDP not initialized for sharding, and not a single process run. Cannot determine configuration." - logger.error(msg) - raise ValueError(msg) - - # Validate that sample_rate is now available on all ranks before use - if sample_rate is None: - msg = f"[Rank {global_rank}] Critical error: sample_rate was not correctly set or broadcasted. Value is None." - logger.error(msg) - raise RuntimeError(msg) - - wandb_logger = None - if args.wandb_entity and args.wandb_project and global_rank == 0: - run_name = args.wandb_name or f"sharded_similarity_{Path(args.manifest).stem}" - wandb_logger = WandbLogger( - project=args.wandb_project, entity=args.wandb_entity, name=run_name, log_model=False - ) - logger.info(f"Wandb logging enabled to {args.wandb_entity}/{args.wandb_project}, run name: {run_name}") - trainer.logger = wandb_logger - - dataset = SpeakerShardedAudioDataset( - assigned_records_list=assigned_records_for_this_rank, - base_audio_dir=args.audio_base_dir, - sample_rate=sample_rate, - ) - - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - shuffle=False, - collate_fn=dataset.collate_fn, - pin_memory=True, - ) - - model = EmbeddingSimilarityExtractorSharded( - output_dir=args.output_dir, - output_file_prefix=Path(args.manifest).stem, - flush_threshold_items=args.flush_threshold_items, - context_min_duration=args.context_min_duration, - context_min_ssim=args.context_min_ssim, - speaker_expected_counts_map=my_speaker_expected_counts, - initial_assigned_count=len(assigned_records_for_this_rank), - max_speaker_items=args.max_speaker_items, - ) - logger.info( - f"Starting prediction with {len(assigned_records_for_this_rank)} records ({len(my_speaker_expected_counts)} unique speakers for this rank according to counts)." 
- ) - trainer.predict(model, dataloaders=dataloader) - - # Rank 0 merges the partial manifests - if global_rank == 0: - final_manifest_path = Path(args.output_dir) / ( - Path(args.manifest).stem - + f"_withContextAudioMinDur{args.context_min_duration}MinSSIM{args.context_min_ssim}.json" - ) - logger.info(f"Merging partial manifest files to `{final_manifest_path}`...") - with open(final_manifest_path, "w", encoding="utf-8") as final_out_f: - for i in range(world_size): - partial_file_path = Path(args.output_dir) / f"{Path(args.manifest).stem}_rank{i}.json" - if partial_file_path.exists(): - with open(partial_file_path, "r", encoding="utf-8") as pf: - for line in pf: - final_out_f.write(line) - logger.info(f"Merged `{partial_file_path}`") - else: - logger.warning(f"Warning - partial manifest file not found: `{partial_file_path}`") - logger.info(f"Merging complete. Final manifest: `{final_manifest_path}`") - - if wandb_logger and global_rank == 0: - wandb.finish() - logger.info("WandB run finished.") - - logger.info(f"Done.") - - -if __name__ == "__main__": - main() diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index cbc0c452a4c7..001c8b307883 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Inference and Evaluation script used for CI and NeMo model evaluation on custom datasets. +Please use this script as an example of how to do inference with MagpieTTS, but this script is otherwise unsupported +for general use cases. +""" import argparse import copy import glob diff --git a/scripts/tts_dataset_to_lhotse/README.md b/scripts/tts_dataset_to_lhotse/README.md deleted file mode 100644 index 436cbdbc4d20..000000000000 --- a/scripts/tts_dataset_to_lhotse/README.md +++ /dev/null @@ -1,82 +0,0 @@ -# Everything Speech Data - -### Single turn speech to speech data - -Our single turn speech to speech data is in the form of conversations such that it will be easy to extend to multi-turn conversations. In this section we will go through the following: - -- Raw manifest format -- Lhotse cuts and shar format -- Creating Lhotse shar and cuts from raw manifest - -#### Raw manifest format - -Users need to get their manifests in the following format for `scripts/speech_data_generation/create_shar.py` to work. 
Each datapoint in the manifest should be in this format: - -``` -{ - 'sample_id': '', - 'normalized_answer_wer': (Optional, this is useful if speech data was synthesized and we want to filter based on wer of the synthesized speech), - 'normalized_answer_cer': (Optional, this is useful if speech data was synthesized and we want to filter based on cer of the synthesized speech), - 'conversations': [ - {'value': '', - 'from': 'user', - 'type': 'audio', - 'duration': (Optional | during shar creation it is automatically calculated), - 'lang': (Optional), - 'instruction': ''(This field is needed for s2s and in direct_s2s this is not needed) - }, - {'value': '', - 'from': 'agent', - 'type': 'audio', - 'duration': (Optional | during shar creation it is automatically calculated), - 'lang': (Optional), - 'transcript': '' - } - ] -} -``` - -#### Lhotse cuts and shar format - -There will be 3 types of files generated after you run `scripts/speech_data_generation/create_shar.py`: - -- cuts.{some_number}.jsonl.gz -- recording.{some_number}.tar -- target_audio.{some_number}.tar - -**recording.{some_number}.tar** - tarred user (input) speech wav files - -**target_audio.{some_number}.tar** - tarred agent (target) speech wav files - -**cuts.{some_number}.jsonl.gz** - You can think of this as the Lhotse manifest. The format or the fields are explained as below (This document will only go over the fields which are used during training/inference) - -This is what a typical cut would look like, which is one datapoint in any of the cuts.{some_number}.jsonl.gz files: -``` -MonoCut(id='squadv2_5705e3a452bb891400689658-2', start=0, duration=17.345306122448978, channel=0, supervisions=[SupervisionSegment(id='squadv2_5705e3a452bb891400689658-2', recording_id='squadv2_5705e3a452bb891400689658-2', start=0, duration=17.345306122448978, channel=0, text='Transcribe and answer:', language='EN', speaker='user', gender=None, custom=None, alignment=None), SupervisionSegment(id='squadv2_5705e3a452bb891400689658-2', recording_id='squadv2_5705e3a452bb891400689658-2', start=0, duration=1.1493877551020408, channel=0, text='NCT of Delhi', language='EN', speaker='agent', gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='squadv2_5705e3a452bb891400689658', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=44100, num_samples=764928, duration=17.345306122448978, channel_ids=[0], transforms=None), custom={'target_audio': Recording(id='_lustre_fsw_portfolios_llmservice_projects_llmservice_nemo_speechlm_data_speech_QA_outputs_speechall_squadv2_train_normalized___audios_squadv2_5705e3a452bb891400689658_synthesized_normalized_answer_audio', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=22050, num_samples=25344, duration=1.1493877551020408, channel_ids=[0], transforms=None), 'shard_origin': PosixPath('/lustre/fs7/portfolios/llmservice/projects/llmservice_nemo_speechlm/data/s2s_synthetic_data/s2s_lhotse_with_wavs/test2/cuts.000000.jsonl.gz'), 'shar_epoch': 0}) -``` - -Explaination of the fields: - -- id: Unique to the datapoint, this is used to reference recording wavs from the corresponding tarred files -- duration: This is the duration of the recording or source audio -- supervisions: Is a list of 2 elements (1 user turn and 1 agent turn) containing the metadata related to each turn. 
-- supervisions[0].text: Instruction from the user -- supervisions[0].speaker: user -- supervisions[0].language: language of input audio -- supervisions[1].text: transcript of target audio -- supervisions[1].speaker: agent -- supervisions[1].language: language of target audio -- custom['target_audio'] - This is the agent or target audio also in the form of a Recording. It has it's own duration, sampling_rate and id -- custom['target_audio'].id - is used to reference target_audios from target_audio tar file. -- custom['target_audio'].duration - self explainatory -- custom['target_audio'].sampling_rate - self explainatory - -#### Creating Lhotse shar and cuts from raw manifest - -To create Lhotse shar and cuts from raw manifests simple run the following command: -``` -python scripts/speech_to_speech_data_generation/create_shars.py \ ---manifest= \ ---out_shar_dir= \ ---num_shard= -``` \ No newline at end of file diff --git a/scripts/tts_dataset_to_lhotse/create_shars.py b/scripts/tts_dataset_to_lhotse/create_shars.py deleted file mode 100644 index e017e59acbf3..000000000000 --- a/scripts/tts_dataset_to_lhotse/create_shars.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
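As a minimal illustration of consuming the cut fields documented in the README above (the manifest path is a placeholder, and it is assumed that Lhotse restores the custom `target_audio` entry as a `Recording` when the cuts manifest is read):

```python
from lhotse import CutSet

cuts = CutSet.from_file("out_shar_dir/cuts.000000.jsonl.gz")
for cut in cuts:
    user_turn, agent_turn = cut.supervisions      # one user turn, one agent turn
    print(user_turn.text, user_turn.language)     # instruction and input-audio language
    print(agent_turn.text, agent_turn.language)   # target transcript and its language
    target = cut.custom["target_audio"]           # agent (target) audio as a Recording
    print(target.duration, target.sampling_rate)
    break
```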
-import argparse -import json -import os -from pathlib import Path - -### from nemo.collections.tts.models import AudioCodecModel -from lhotse import CutSet, Recording, SupervisionSegment -from lhotse.audio import RecordingSet -from lhotse.shar.writers import AudioTarWriter -from tqdm import tqdm - - -def json_reader(filename): - with open(filename) as f: - for line in f: - yield json.loads(line) - - -def create_shar_from_manifest(manifest, audio_root_path, out_shar_dir, shard_size=1000): - in_manifest = list(json_reader(manifest)) - print(f"...loaded {manifest} # of datapoints {len(in_manifest)}") - num_shard = int(len(in_manifest) / shard_size) - if len(in_manifest) % shard_size != 0: - shard_size += 1 - print(f"shard_size {shard_size} num_shards {num_shard}") - - user_recordings = [] - answer_list = [] - instructions = [] - source_language = [] - target_language = [] - target_recordings = [] - for i, line in tqdm(enumerate(in_manifest)): - # For single turn convs is a list of 2 elements - # First element is user speech and second is agent speech - - # User_Speech - context_audio_path = line["context_audio_filepath"] - - user_recording = Recording.from_file(os.path.join(audio_root_path, context_audio_path)) - user_recordings.append(user_recording) - - # This are the context text, this could be different things like a simple instruction or details about speaker voice - instructions.append(" ") - - # Language source - if "lang" in line: - language = line["lang"] - elif "language" in line: - language = line["language"] - elif "Language:" in str(line["speaker"]): - language = line["speaker"].split("Language:")[1].split(" ")[0] - else: - language = "en" - - source_language.append(language) - - # Loading agent audio and using only the extracted features as nd.array - target_recordings.append(Recording.from_file(os.path.join(audio_root_path, line["audio_filepath"]))) - # Agent answer transcript - answer_list.append(line["text"]) - # Language target - target_language.append(language) - - print("Done extracting data from manifest") - print(len(user_recordings)) - cuts = CutSet.from_manifests(recordings=RecordingSet.from_recordings(user_recordings)) - - # Attach text - for i, cut in tqdm(enumerate(cuts)): - cut.supervisions.append( - SupervisionSegment( - id=cut.id, - recording_id=cut.id, - start=0, - duration=cut.recording.duration, - text=instructions[i], - speaker="user", - language=source_language[i].upper(), - ), - ) - cut.supervisions.append( - SupervisionSegment( - id=cut.id, - recording_id=cut.id, - start=0, - duration=target_recordings[i].duration, - text=answer_list[i], - speaker="agent", - language=target_language[i].upper(), - ), - ) - cut.target_audio = target_recordings[i] - - print("...Making Shars") - out_shar_dir = Path(out_shar_dir) - out_shar_dir.mkdir(parents=True, exist_ok=True) - assert len(user_recordings) % shard_size != 0, "Lhotse breaks if feat_list is a multiple of shard_size" - exported = cuts.to_shar(out_shar_dir, fields={"recording": "wav"}, num_jobs=4, shard_size=shard_size) - print(f"...share created") - - print(f"...Exporting target_audio to tar files") - for i, path in tqdm(enumerate(exported["cuts"])): - path = path[0] - out_path = path.replace("cuts", "target_audio").replace(".jsonl.gz", ".tar") - with AudioTarWriter(out_path, shard_size=None, format="flac") as writer: - for cut in CutSet.from_file(path): - writer.write(cut.id, cut.target_audio.load_audio(), manifest=cut.target_audio, sampling_rate=22050) - print(f"...Exported target_audio to tar files") - - 
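The shards written above can be re-read with Lhotse's shar reader. The snippet below is a hypothetical usage sketch: the directory and shard indices are placeholders, and `load_target_audio` relies on Lhotse's generic custom-field loading rather than anything defined in this script:

```python
from lhotse import CutSet

cuts = CutSet.from_shar(
    fields={
        "cuts": ["out_shar_dir/cuts.000000.jsonl.gz"],
        "recording": ["out_shar_dir/recording.000000.tar"],
        "target_audio": ["out_shar_dir/target_audio.000000.tar"],
    }
)
for cut in cuts:
    user_audio = cut.load_audio()          # user (input) speech from recording.*.tar
    agent_audio = cut.load_target_audio()  # agent (target) speech from target_audio.*.tar
    print(user_audio.shape, agent_audio.shape)
    break
```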
-def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--manifest', - type=str, - default="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/manifests/hifitts__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json", - ) - parser.add_argument( - '--audio_root_path', - type=str, - default="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/hi_fi_tts_v0/", - ) - parser.add_argument( - '--out_shar_dir', - type=str, - default="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/hifitts/", - ) - parser.add_argument( - '--shard_size', - type=int, - default=1000, - ) - - args = parser.parse_args() - print(f"manifest {args.manifest}") - print(f"audio_root_path {args.audio_root_path}") - print(f"out_shar_dir {args.out_shar_dir}") - print(f"num_shard {args.shard_size}") - - create_shar_from_manifest( - manifest=args.manifest, - audio_root_path=args.audio_root_path, - out_shar_dir=args.out_shar_dir, - shard_size=args.shard_size, - ) - - -if __name__ == "__main__": - main() From 8c732c5fcd18de10fb93a579900b0415ee314ae7 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 4 Nov 2025 09:19:19 -0800 Subject: [PATCH 112/113] CodeQL and Lint fixes Signed-off-by: Jason --- nemo/collections/tts/models/magpietts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 942e639a641c..cc6c84a234af 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -29,7 +29,6 @@ from torch import nn from torch.utils.data import get_worker_info -import nemo.collections.asr as nemo_asr from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss @@ -1462,6 +1461,7 @@ def prepare_context_tensors(self, batch): elif self.model_type in ['decoder_context_tts', 'decoder_ce']: dec_context_size = context_mask.size(1) + context_embeddings = None # Address CodeQL if self.model_type == 'decoder_context_tts': context_embeddings = context_input_embedded elif self.model_type == 'decoder_ce': From 1acbc8ffe36550bc4cd11ef6416aa889ef908609 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 4 Nov 2025 11:24:06 -0800 Subject: [PATCH 113/113] Update confs and readmes Signed-off-by: Jason --- .../{magpietts_en.yaml => magpietts.yaml} | 44 +- .../tts/conf/magpietts/magpietts_dc_en.yaml | 174 ------- .../magpietts_inference_multilingual_v1.yaml | 224 --------- ...hotse_dc_en.yaml => magpietts_lhotse.yaml} | 73 ++- .../magpietts_lhotse_dc_en_tiny.yaml | 194 -------- .../magpietts/magpietts_multilingual_v1.yaml | 234 --------- .../magpietts_multilingual_v2_lhotse.yaml | 262 ---------- ...ce_en.yaml => magpietts_po_inference.yaml} | 56 ++- examples/tts/magpietts.py | 2 +- examples/tts/t5tts_commands.md | 462 ------------------ scripts/magpietts/README_magpie_po.md | 10 +- 11 files changed, 115 insertions(+), 1620 deletions(-) rename examples/tts/conf/magpietts/{magpietts_en.yaml => magpietts.yaml} (80%) delete mode 100644 examples/tts/conf/magpietts/magpietts_dc_en.yaml delete mode 100644 examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml rename examples/tts/conf/magpietts/{magpietts_lhotse_dc_en.yaml => magpietts_lhotse.yaml} (75%) delete mode 100644 examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml delete mode 100644 
examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml delete mode 100644 examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml rename examples/tts/conf/magpietts/{magpietts_inference_en.yaml => magpietts_po_inference.yaml} (76%) delete mode 100644 examples/tts/t5tts_commands.md diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts.yaml similarity index 80% rename from examples/tts/conf/magpietts/magpietts_en.yaml rename to examples/tts/conf/magpietts/magpietts.yaml index a68f53cba5d1..4c45f38fb4b3 100644 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ b/examples/tts/conf/magpietts/magpietts.yaml @@ -1,4 +1,4 @@ -name: Magpie-TTS-EN +name: Magpie-TTS max_epochs: ??? # Adjust batch size based on GPU memory @@ -8,21 +8,20 @@ batch_size: 16 weighted_sampling_steps_per_epoch: null # Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +# See DatasetMeta in https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/tts/data/text_to_speech_dataset.py train_ds_meta: ??? val_ds_meta: ??? model: - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 + model_type: "decoder_ce" # decoder_context_tts or decoder_ce + use_text_conditioning_encoder: true # Enable or disable text context. Audio context is always enabled + text_conditioning_tokenizer_name: text_ce_tokenizer # The tokenizer to be used for text contexts + context_duration_min: 5.0 context_duration_max: 5.0 load_cached_codes_if_available: true prior_scaling_factor: 0.5 prior_end_step: 12000 - prior_scaledown_start_step: 8000 # Prior will always be on before this step. + prior_scaledown_start_step: 8000 indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. alignment_loss_scale: 0.002 embedding_dim: 768 @@ -51,14 +50,14 @@ model: local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 3 + local_transformer_n_layers: 1 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
- text_tokenizers: # Add more languages for multi-lingual TTS + text_tokenizers: english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer punct: true @@ -72,6 +71,13 @@ model: ignore_ambiguous_words: false use_chars: true use_stresses: true + text_ce_tokenizer: # Used for text context + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + ### For additional languages, consider adding a generic byt5 tokenizer like the one below + # french_chartokenizer: # Used for text context + # _target_: AutoTokenizer + # pretrained_model: "google/byt5-small" train_ds: dataset: @@ -113,8 +119,8 @@ model: max_length_causal_mask: 2048 use_learnable_pos_emb: true - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 + context_encoder: # Only used for decoder_ce (and multi_encoder_context_tts), ignored otherwise + n_layers: 1 d_model: 768 d_ffn: 3072 sa_n_heads: 12 @@ -136,6 +142,7 @@ model: p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: true + xa_d_head: 128 xa_d_memory: 768 xa_n_heads: 1 is_causal: true @@ -143,10 +150,11 @@ model: apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true + make_prior_window_strict: true optim: _target_: torch.optim.AdamW - lr: 1e-4 + lr: 2e-4 sched: name: ExponentialLR @@ -157,7 +165,7 @@ trainer: devices: -1 accelerator: gpu strategy: ddp_find_unused_parameters_true - precision: bf16-mixed + precision: 32 max_epochs: ${max_epochs} accumulate_grad_batches: 1 enable_checkpointing: False # Provided by exp_manager @@ -174,9 +182,11 @@ exp_manager: create_tensorboard_logger: true create_wandb_logger: false wandb_logger_kwargs: - name: null + entity: null project: null - resume: true + group: null + name: ${name} + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. create_checkpoint_callback: true checkpoint_callback_params: monitor: val_loss @@ -184,6 +194,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true - filename: "${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}" + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' resume_if_exists: true resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_dc_en.yaml deleted file mode 100644 index 8765af931139..000000000000 --- a/examples/tts/conf/magpietts/magpietts_dc_en.yaml +++ /dev/null @@ -1,174 +0,0 @@ -name: Magpie-TTS-EN - -max_epochs: ??? -# Adjust batch size based on GPU memory -batch_size: 16 -# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. -# If null, then weighted sampling is disabled. -weighted_sampling_steps_per_epoch: null - -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -train_ds_meta: ??? -val_ds_meta: ??? - -model: - model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. - context_duration_min: 5.0 - context_duration_max: 5.0 - load_cached_codes_if_available: true - prior_scaling_factor: 0.5 - prior_end_step: 12000 - prior_scaledown_start_step: 8000 # Prior will always be on before this step. 
- indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. - alignment_loss_scale: 0.002 - embedding_dim: 768 - codecmodel_path: ??? - max_epochs: ${max_epochs} - steps_per_epoch: ${weighted_sampling_steps_per_epoch} - cfg_unconditional_prob: 0.1 - - # Alignment encoder parameters, to binarize the prior - # This is used for attention-constrained training and inference - use_alignment_encoder: false - # Below args are only relevant if use_alignment_encoder is true - use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder - alignment_encoder_loss_scale: 1.0 - binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. - binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. - prior_future_context: 2 # Future window of the binarized prior. - prior_past_context: 2 # Past window of the binarized prior. - prior_future_decay: 0.8 # Decay factor for future context - prior_past_decay: 0.5 # Decay factor for past context - binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs - binarized_prior_epsilon: 0.0 - aligner_encoder_train_steps: 50000 - - # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" - # Below args are only relevant if use_local_transformer is true - local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 3 - local_transformer_n_heads: 1 - local_transformer_hidden_dim: 256 - - text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. - text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
- - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - - train_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${train_ds_meta} - weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - min_duration: 0.2 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 4 - drop_last: true - pin_memory: true - - validation_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${val_ds_meta} - min_duration: 0.2 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 4 - pin_memory: true - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 1 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 768 - xa_n_heads: 1 - xa_d_head: 128 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.AdamW - lr: 1e-4 - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: bf16-mixed - max_epochs: ${max_epochs} - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - check_val_every_n_epoch: 1 - num_sanity_val_steps: 0 - benchmark: false - gradient_clip_val: 2.5 - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null - resume: true - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - filename: "${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}" - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml deleted file mode 100644 index 6e8bbf47a7d1..000000000000 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ /dev/null @@ -1,224 +0,0 @@ -name: Magpie-TTS-ML-V1-Infer -mode: test -init_from_ptl_ckpt: ??? -max_epochs: 1 -# Adjust batch size based on GPU memory -batch_size: 16 -# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. -# If null, then weighted sampling is disabled. -weighted_sampling_steps_per_epoch: null - -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -test_ds_meta: ??? 
- -phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" -heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" -model: - use_kv_cache_for_inference: true - inference_temperature: 0.7 - inference_topk: 80 - inference_use_cfg: false - inference_cfg_scale: 1.0 - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 - max_decoder_steps: 500 - load_cached_codes_if_available: true - prior_scaling_factor: null - prior_end_step: 0 - prior_scaledown_start_step: 0 - alignment_loss_scale: 0.0 - embedding_dim: 768 - codecmodel_path: null - max_epochs: ${max_epochs} - steps_per_epoch: ${weighted_sampling_steps_per_epoch} - - # Alignment encoder parameters, to binarize the prior - # This is used for attention-constrained training and inference - use_alignment_encoder: false - # Below args are only relevant if use_alignment_encoder is true - use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder - alignment_encoder_loss_scale: 1.0 - binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. - binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. - prior_future_context: 2 # Future window of the binarized prior. - prior_past_context: 2 # Past window of the binarized prior. - prior_future_decay: 0.8 # Decay factor for future context - prior_past_decay: 0.5 # Decay factor for past context - binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs - binarized_prior_epsilon: 0.0 - aligner_encoder_train_steps: 50000 - - # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" - # Below args are only relevant if use_local_transformer is true - local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 3 - local_transformer_n_heads: 1 - local_transformer_hidden_dim: 256 - - text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. - text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
- - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - spanish_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: es-ES - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: es-ES - phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - german_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: de-DE - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: 'de-DE' - phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" - heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - grapheme_case: mixed - grapheme_prefix: '#' - mandarin_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p - phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" - word_segmenter: "jieba" - phoneme_prefix: "" - phoneme_case: "lower" - tone_prefix: "#" - ascii_letter_prefix: "" - ascii_letter_case: "upper" - - test_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${test_ds_meta} - min_duration: 0.5 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 2 - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: False - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 768 - xa_n_heads: 12 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: 32 - max_epochs: ${max_epochs} - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - val_check_interval: 500 - 
benchmark: false - gradient_clip_val: 2.5 - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml b/examples/tts/conf/magpietts/magpietts_lhotse.yaml similarity index 75% rename from examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml rename to examples/tts/conf/magpietts/magpietts_lhotse.yaml index a21e32063dc8..241fddfff232 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse.yaml @@ -1,22 +1,22 @@ -name: MagpieTTS-EN-Lhotse +name: Magpie-TTS quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. - model: use_lhotse: true - model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. + model_type: "decoder_ce" # decoder_context_tts or decoder_ce + use_text_conditioning_encoder: true # Enable or disable text context. Audio context is always enabled + text_conditioning_tokenizer_name: text_ce_tokenizer # The tokenizer to be used for text contexts context_duration_min: 5.0 context_duration_max: 5.0 load_cached_codes_if_available: true prior_scaling_factor: 0.5 - prior_end_step: 12_000 - prior_scaledown_start_step: 8_000 # Prior will always be on before this step. - indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. + prior_end_step: 12000 + prior_scaledown_start_step: 8000 + indefinite_prior_prob: 0. # If > 0, then prior will be applied after prior_end_step with this probability. alignment_loss_scale: 0.002 embedding_dim: 768 codecmodel_path: ??? - cfg_unconditional_prob: 0.1 # enable classifier-free guidance during traing by dropping out conditionals with this probability + cfg_unconditional_prob: 0.1 # Alignment encoder parameters, to binarize the prior # This is used for attention-constrained training and inference @@ -24,7 +24,7 @@ model: # Below args are only relevant if use_alignment_encoder is true use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder alignment_encoder_loss_scale: 1.0 - binarize_prior_after_step: 10_000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. prior_future_context: 2 # Future window of the binarized prior. prior_past_context: 2 # Past window of the binarized prior. @@ -32,20 +32,20 @@ model: prior_past_decay: 0.5 # Decay factor for past context binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. 
Increase this for low frame rate codecs binarized_prior_epsilon: 0.0 - aligner_encoder_train_steps: 50_000 + aligner_encoder_train_steps: 50000 # Local transformer parameters for autoregressive codebook prediction within a frame local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 3 + local_transformer_n_layers: 1 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. - text_tokenizers: # Add more languages for multi-lingual TTS + text_tokenizers: english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer punct: true @@ -59,6 +59,13 @@ model: ignore_ambiguous_words: false use_chars: true use_stresses: true + text_ce_tokenizer: # Used for text context + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + ### For additional languages, consider adding a generic byt5 tokenizer like the one below + # french_chartokenizer: # Used for text context + # _target_: AutoTokenizer + # pretrained_model: "google/byt5-small" train_ds: use_lhotse: ${model.use_lhotse} @@ -71,14 +78,14 @@ model: batch_duration : ??? # in seconds. Adjust based on your GPU memory. quadratic_duration: ${quadratic_duration} use_bucketing: true - num_buckets: 10 + num_buckets: 20 bucket_buffer_size: 20_000 shuffle_buffer_size: 20_000 num_cuts_for_bins_estimate: 20_000 shard_seed: "trng" drop_last: true shuffle: true - num_workers: 4 + num_workers: 6 pin_memory: true input_cfg: @@ -88,6 +95,7 @@ model: tags: tokenizer_names: ["english_phoneme"] + validation_ds: use_lhotse: ${model.use_lhotse} volume_norm: true @@ -102,7 +110,7 @@ model: force_finite: true drop_last: false shuffle: false - num_workers: 4 + num_workers: 2 pin_memory: true input_cfg: @@ -115,7 +123,7 @@ model: encoder: n_layers: 6 d_model: 768 - d_ffn: 3_072 + d_ffn: 3072 sa_n_heads: 12 kernel_size: 3 p_dropout: 0.1 @@ -123,30 +131,45 @@ model: has_xattn: false is_causal: true apply_norm_out: true - max_length_causal_mask: 2_048 + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + context_encoder: # Only used for decoder_ce (and multi_encoder_context_tts), ignored otherwise + n_layers: 1 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: false + apply_norm_out: true + max_length_causal_mask: 2048 use_learnable_pos_emb: true decoder: n_layers: 12 d_model: 768 - d_ffn: 3_072 + d_ffn: 3072 sa_n_heads: 12 kernel_size: 1 p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: true + xa_d_head: 128 xa_d_memory: 768 xa_n_heads: 1 - xa_d_head: 128 is_causal: true apply_norm_to_cond: true apply_norm_out: true - max_length_causal_mask: 2_048 + max_length_causal_mask: 2048 use_learnable_pos_emb: true + make_prior_window_strict: true optim: _target_: torch.optim.AdamW - lr: 1e-4 + lr: 2e-4 sched: name: ExponentialLR @@ -157,12 +180,13 @@ trainer: devices: -1 accelerator: gpu strategy: ddp_find_unused_parameters_true - precision: bf16-mixed + precision: 32 max_steps: ??? 
accumulate_grad_batches: 1 - enable_checkpointing: false # Provided by exp_manager - logger: false # Provided by exp_manager + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager log_every_n_steps: 100 + check_val_every_n_epoch: 1 limit_train_batches: 1_000 val_check_interval: 1_000 num_sanity_val_steps: 0 @@ -170,7 +194,6 @@ trainer: use_distributed_sampler: false # required because Lhotse has its own handling gradient_clip_val: 2.5 - exp_manager: exp_dir: null name: ${name} diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml deleted file mode 100644 index 714a72f4ffa5..000000000000 --- a/examples/tts/conf/magpietts/magpietts_lhotse_dc_en_tiny.yaml +++ /dev/null @@ -1,194 +0,0 @@ -name: MagpieTTS-EN-Lhotse - -quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. - -model: - use_lhotse: true - model_type: "decoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. - context_duration_min: 5.0 - context_duration_max: 5.0 - load_cached_codes_if_available: true - prior_scaling_factor: 0.5 - prior_end_step: 12_000 - prior_scaledown_start_step: 8_000 # Prior will always be on before this step. - indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. - alignment_loss_scale: 0.002 - embedding_dim: 512 - codecmodel_path: ??? - cfg_unconditional_prob: 0.1 # enable classifier-free guidance during traing by dropping out conditionals with this probability - - # Alignment encoder parameters, to binarize the prior - # This is used for attention-constrained training and inference - use_alignment_encoder: false - # Below args are only relevant if use_alignment_encoder is true - use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder - alignment_encoder_loss_scale: 1.0 - binarize_prior_after_step: 10_000 # Switch from beta-binomial prior to binarized prior after this step. - binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. - prior_future_context: 2 # Future window of the binarized prior. - prior_past_context: 2 # Past window of the binarized prior. - prior_future_decay: 0.8 # Decay factor for future context - prior_past_decay: 0.5 # Decay factor for past context - binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs - binarized_prior_epsilon: 0.0 - aligner_encoder_train_steps: 50_000 - - # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" - # Below args are only relevant if use_local_transformer is true - local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 3 - local_transformer_n_heads: 1 - local_transformer_hidden_dim: 256 - - text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. - text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
- - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - - train_ds: - use_lhotse: ${model.use_lhotse} - volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration : ??? # in seconds. Adjust based on your GPU memory. - quadratic_duration: ${quadratic_duration} - use_bucketing: true - num_buckets: 10 - bucket_buffer_size: 20_000 - shuffle_buffer_size: 20_000 - num_cuts_for_bins_estimate: 20_000 - shard_seed: "trng" - drop_last: true - shuffle: true - num_workers: 4 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] - - validation_ds: - use_lhotse: ${model.use_lhotse} - volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. - quadratic_duration: ${quadratic_duration} - use_bucketing: false - force_finite: true - drop_last: false - shuffle: false - num_workers: 4 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] - - encoder: - n_layers: 6 - d_model: 512 - d_ffn: 2048 - sa_n_heads: 8 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: true - apply_norm_out: true - max_length_causal_mask: 2_048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 512 - d_ffn: 2048 - sa_n_heads: 8 - kernel_size: 1 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 512 - xa_n_heads: 1 - xa_d_head: 128 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2_048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.AdamW - lr: 1e-4 - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: bf16-mixed - max_steps: ??? - accumulate_grad_batches: 1 - enable_checkpointing: false # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - limit_train_batches: 1_000 - val_check_interval: 1_000 - num_sanity_val_steps: 0 - benchmark: false - use_distributed_sampler: false # required because Lhotse has its own handling - gradient_clip_val: 2.5 - - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - entity: null - project: null - group: null - name: ${name} - resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. 
- create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml deleted file mode 100644 index 14529a3172c5..000000000000 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ /dev/null @@ -1,234 +0,0 @@ -name: Magpie-TTS-ML-V1 - -max_epochs: ??? -# Adjust batch size based on GPU memory -batch_size: 16 -# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. -# If null, then weighted sampling is disabled. -weighted_sampling_steps_per_epoch: null - -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -train_ds_meta: ??? -val_ds_meta: ??? - -model: - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise - load_cached_codes_if_available: true - prior_scaling_factor: 0.5 - prior_end_step: 12000 - prior_scaledown_start_step: 8000 - indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. - alignment_loss_scale: 0.002 - embedding_dim: 768 - codecmodel_path: ??? - max_epochs: ${max_epochs} - steps_per_epoch: ${weighted_sampling_steps_per_epoch} - cfg_unconditional_prob: 0.1 - # Alignment encoder parameters, to binarize the prior - # This is used for attention-constrained training and inference - use_alignment_encoder: false - # Below args are only relevant if use_alignment_encoder is true - use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder - alignment_encoder_loss_scale: 1.0 - binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. - binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. - prior_future_context: 2 # Future window of the binarized prior. - prior_past_context: 2 # Past window of the binarized prior. - prior_future_decay: 0.8 # Decay factor for future context - prior_past_decay: 0.5 # Decay factor for past context - binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. 
Increase this for low frame rate codecs - binarized_prior_epsilon: 0.0 - aligner_encoder_train_steps: 50000 - - # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" - # Below args are only relevant if use_local_transformer is true - local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 3 - local_transformer_n_heads: 1 - local_transformer_hidden_dim: 256 - - text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. - text_context_remapping_prob: 0.0 - - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - spanish_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: es-ES - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: es-ES - phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - german_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: de-DE - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: 'de-DE' - phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" - heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - grapheme_case: mixed - grapheme_prefix: '#' - mandarin_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p - phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" - word_segmenter: "jieba" - phoneme_prefix: "" - phoneme_case: "lower" - tone_prefix: "#" - ascii_letter_prefix: "" - ascii_letter_case: "upper" - - train_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${train_ds_meta} - weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - min_duration: 0.2 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 4 - drop_last: true - pin_memory: true - - validation_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${val_ds_meta} - min_duration: 0.2 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 4 - pin_memory: true - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - context_encoder: # Only used 
for multi_encoder_context_tts, ignored otherwise - n_layers: 3 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 1 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 768 - xa_n_heads: 1 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.AdamW - lr: 1e-4 - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: bf16-mixed - max_epochs: ${max_epochs} - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - check_val_every_n_epoch: 1 - num_sanity_val_steps: 0 - benchmark: false - gradient_clip_val: 2.5 - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null - resume: true - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml deleted file mode 100644 index b2be1234946e..000000000000 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml +++ /dev/null @@ -1,262 +0,0 @@ -name: Magpie-TTS-ML - -quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -train_ds_meta: ??? -val_ds_meta: ??? - -model: - use_lhotse: true - model_type: "decoder_ce" # single_encoder_sv_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. - text_conditioning_tokenizer_name: text_ce_tokenizer - context_duration_min: 5.0 - context_duration_max: 5.0 - load_cached_codes_if_available: true - prior_scaling_factor: 0.5 - prior_end_step: 12000 - prior_scaledown_start_step: 8000 - indefinite_prior_prob: 0. # If > 0, then prior will be applied after prior_end_step with this probability. - alignment_loss_scale: 0.002 - embedding_dim: 768 - codecmodel_path: ??? - cfg_unconditional_prob: 0.1 - # Alignment encoder parameters, to binarize the prior - # This is used for attention-constrained training and inference - use_alignment_encoder: false - - # Local transformer parameters for autoregressive codebook prediction within a frame - local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" - # Below args are only relevant if use_local_transformer is true - local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 1 - local_transformer_n_heads: 1 - local_transformer_hidden_dim: 256 - - text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. 
Does not need to cover all text contexts. - text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. - - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - spanish_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: es-ES - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: es-ES - phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - german_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: de-DE - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: 'de-DE' - phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" - heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - grapheme_case: mixed - grapheme_prefix: '#' - mandarin_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p - phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" - word_segmenter: "jieba" - phoneme_prefix: "" - phoneme_case: "lower" - tone_prefix: "#" - ascii_letter_prefix: "" - ascii_letter_case: "upper" - french_chartokenizer: - _target_: AutoTokenizer - pretrained_model: "google/byt5-small" - hindi_phoneme: - _target_: AutoTokenizer - pretrained_model: "google/byt5-small" - italian_phoneme: - _target_: AutoTokenizer - pretrained_model: "google/byt5-small" - vietnamese_phoneme: - _target_: AutoTokenizer - pretrained_model: "google/byt5-small" - text_ce_tokenizer: - _target_: AutoTokenizer - pretrained_model: "google/byt5-small" - - train_ds: - use_lhotse: ${model.use_lhotse} - volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration : ??? # in seconds. Adjust based on your GPU memory. - quadratic_duration: ${quadratic_duration} - use_bucketing: true - num_buckets: 20 - bucket_buffer_size: 20_000 - shuffle_buffer_size: 20_000 - num_cuts_for_bins_estimate: 20_000 - shard_seed: "trng" - drop_last: true - shuffle: true - num_workers: 6 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] - - - validation_ds: - use_lhotse: ${model.use_lhotse} - volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. 
- quadratic_duration: ${quadratic_duration} - use_bucketing: false - force_finite: true - drop_last: false - shuffle: false - num_workers: 2 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - context_encoder: # Only used for multi_encoder_context_tts and decoder_ce - n_layers: 1 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 1 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_head: 128 - xa_d_memory: 768 - xa_n_heads: 1 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - make_prior_window_strict: true - - optim: - _target_: torch.optim.AdamW - lr: 2e-4 - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: 32 - max_steps: ??? - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - check_val_every_n_epoch: 1 - limit_train_batches: 1_000 - val_check_interval: 1_000 - num_sanity_val_steps: 0 - benchmark: false - use_distributed_sampler: false # required because Lhotse has its own handling - gradient_clip_val: 2.5 - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - entity: null - project: null - group: null - name: ${name} - resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_po_inference.yaml similarity index 76% rename from examples/tts/conf/magpietts/magpietts_inference_en.yaml rename to examples/tts/conf/magpietts/magpietts_po_inference.yaml index 993868ea6959..735e750a899e 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_po_inference.yaml @@ -1,4 +1,4 @@ -name: Magpie-TTS-EN-Infer +name: MagpieTTS-PO-Infer mode: test init_from_ptl_ckpt: ??? max_epochs: 1 @@ -9,32 +9,32 @@ batch_size: 16 weighted_sampling_steps_per_epoch: null # Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +# See DatasetMeta in https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/tts/data/text_to_speech_dataset.py test_ds_meta: ??? 
phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" model: + # Inference hyperparameters use_kv_cache_for_inference: true inference_temperature: 0.7 inference_topk: 80 inference_use_cfg: false inference_cfg_scale: 1.0 - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 max_decoder_steps: 500 + + model_type: "decoder_ce" # decoder_context_tts or decoder_ce + use_text_conditioning_encoder: true # Enable or disable text context. Audio context is always enabled + text_conditioning_tokenizer_name: text_ce_tokenizer # The tokenizer to be used for text contexts + context_duration_min: 5.0 + context_duration_max: 5.0 load_cached_codes_if_available: true prior_scaling_factor: null prior_end_step: 0 prior_scaledown_start_step: 0 alignment_loss_scale: 0.0 embedding_dim: 768 - codecmodel_path: null + codecmodel_path: ??? max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} @@ -58,14 +58,14 @@ model: local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" # Below args are only relevant if use_local_transformer is true local_transformer_loss_scale: 1.0 - local_transformer_n_layers: 3 + local_transformer_n_layers: 1 local_transformer_n_heads: 1 local_transformer_hidden_dim: 256 text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
- text_tokenizers: # Add more languages for multi-lingual TTS + text_tokenizers: english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer punct: true @@ -79,12 +79,19 @@ model: ignore_ambiguous_words: false use_chars: true use_stresses: true + text_ce_tokenizer: # Used for text context + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + ### For additional languages, consider adding a generic byt5 tokenizer like the one below + # french_chartokenizer: # Used for text context + # _target_: AutoTokenizer + # pretrained_model: "google/byt5-small" test_ds: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${test_ds_meta} - min_duration: 0.5 + min_duration: 0.2 max_duration: 20.0 dataloader_params: @@ -100,13 +107,13 @@ model: p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: false - is_causal: False + is_causal: true apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 + context_encoder: # Only used for decoder_ce (and multi_encoder_context_tts), ignored otherwise + n_layers: 1 d_model: 768 d_ffn: 3072 sa_n_heads: 12 @@ -124,22 +131,23 @@ model: d_model: 768 d_ffn: 3072 sa_n_heads: 12 - kernel_size: 3 + kernel_size: 1 p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: true + xa_d_head: 128 xa_d_memory: 768 - xa_n_heads: 12 + xa_n_heads: 1 is_causal: true apply_norm_to_cond: true apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true + make_prior_window_strict: true optim: - _target_: torch.optim.Adam + _target_: torch.optim.AdamW lr: 2e-4 - betas: [0.8, 0.99] sched: name: ExponentialLR @@ -166,8 +174,11 @@ exp_manager: create_tensorboard_logger: true create_wandb_logger: false wandb_logger_kwargs: - name: null + entity: null project: null + group: null + name: ${name} + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. create_checkpoint_callback: true checkpoint_callback_params: monitor: val_loss @@ -175,5 +186,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' resume_if_exists: true - resume_ignore_no_checkpoint: true + resume_ignore_no_checkpoint: true \ No newline at end of file diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index c72da5318287..4676afad7626 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -27,7 +27,7 @@ from nemo.utils.exp_manager import exp_manager -@hydra_runner(config_path="conf/magpietts", config_name="magpietts_en") +@hydra_runner(config_path="conf/magpietts", config_name="magpietts_lhotse") def main(cfg): logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg, resolve=True)) diff --git a/examples/tts/t5tts_commands.md b/examples/tts/t5tts_commands.md deleted file mode 100644 index 4b76ceb157fb..000000000000 --- a/examples/tts/t5tts_commands.md +++ /dev/null @@ -1,462 +0,0 @@ -## Docker Container - -`nvcr.io/nvidia/nemo:dev` or `nvcr.io/nvidia/nemo:24.07` - -I have an sqsh file already built for `nvcr.io/nvidia/nemo:dev` - much faster to start on eos/draco than giving the above docker containers. 
-Sqsh path on EOS: `/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/launchscripts/nemodevNov24.sqsh` - -Docker commands I run locally - -``` -docker run --runtime=nvidia -it --rm -v --shm-size=16g --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v /home/pneekhara/2023:/home/pneekhara/2023 -v /datap/misc/:/datap/misc/ -v ~/.cache/torch:/root/.cache/torch -v ~/.netrc:/root/.netrc -v ~/.ssh:/root/.ssh --net=host nvcr.io/nvidia/nemo:dev -``` - -``` -cd /home/pneekhara/2023/SimpleT5NeMo/NeMo; export PYTHONPATH="/home/pneekhara/2023/SimpleT5NeMo/NeMo.:${PYTHONPATH}" ; -``` - -## Code -* Model `nemo/collections/tts/models/t5tts.py` -* Dataset Class `T5TTSDataset` in `nemo/collections/tts/data/text_to_speech_dataset.py` -* Transformer Module `nemo/collections/tts/modules/t5tts_transformer.py` -* Config Yaml `examples/tts/conf/t5tts/t5tts.yaml` -* Training/Inference Script `examples/tts/t5tts.py` - -## Model Types - -Currently supports four model types: `single_encoder_sv_tts`, `multi_encoder_context_tts`, `decoder_context_tts` and `decoder_pretrain_synthesizer` (`cfg.model.model_type` in t5tts.yaml) - -1. `single_encoder_sv_tts` is a simple T5 model: Text goes into the encoder and target audio goes to the decoder. - Additionally, the TitaNet speaker embedding of the target audio (or context audio, if provided) gets added to the encoder output (all timesteps). Text context is not supported in this model. - -2. `multi_encoder_context_tts` is a multi-encoder T5 model: Transcript and context audio go to different encoders. - Transcript encoding feeds into the layers given by `cfg.model.transcript_decoder_layers` and the context encoding feeds into the layers given by `context_decoder_layers`. - Also supports text context, which gets encoded by the same encoder as context audio. Only one of context audio or context text is supported. - -3. `decoder_context_tts`: Text goes into the encoder; context & target audio go to the decoder. - Also supports text context. Currently, I have tested the model using fixed-size contexts, so I set `context_duration_min` and `context_duration_max` to the same value (5 seconds). Text context, which is usually shorter than the number of codec frames in 5 seconds of audio, is padded to the max context duration in this model. - -4. `decoder_pretrain_synthesizer`: This is the model type used for pretraining the decoder only on audio data using a next-frame prediction loss. - -## Training - -### Manifest structure -For `single_encoder_sv_tts`, the manifest json files should contain the following keys: `audio_filepath, duration, text, speaker`. `speaker` is not currently being used, so it can be anything. Optionally, we can have a `context_audio_filepath` and `context_audio_duration` as well, if we want to use that for speaker embedding instead of the `audio_filepath`. -If we have already extracted the audio codes, then they can also contain the key `target_audio_codes_path` pointing to the absolute path to the codes .pt file of shape (8, T). -Note: `target_audio_codes_path` should either be present in ALL training manifests or absent in ALL training manifests. The train set cannot be a mix of both; the same goes for the val set. -If `target_audio_codes_path` is not present, codes are extracted on the fly (and training will be slower). - -For `multi_encoder_context_tts`, `decoder_context_tts`, in addition to the above, the manifest should contain `context_audio_filepath` and `context_audio_duration`. If we have codes already extracted, we can have `context_audio_codes_path` (absolute path) instead of `context_audio_filepath`.
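For concreteness, here is a minimal sketch of how one such manifest line could be written; the paths, speaker name, and durations below are hypothetical placeholders, not real data:

```
# Sketch: append one training-manifest record with the keys described above.
# All paths/values here are hypothetical placeholders.
import json

record = {
    "audio_filepath": "audio/spk1/utt_0001.wav",  # relative to audio_dir
    "duration": 4.2,
    "text": "Hello world.",
    "speaker": "spk1",  # currently unused, can be anything
    "context_audio_filepath": "audio/spk1/utt_0042.wav",
    "context_audio_duration": 5.1,
    # Optional: absolute path to pre-extracted codes of shape (8, T)
    # "target_audio_codes_path": "/abs/path/to/utt_0001_codes.pt",
}

with open("train_manifest.json", "a") as f:
    f.write(json.dumps(record) + "\n")
```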
- -For text context training, we can have `context_text` key for text context and drop `context_audio_duration` and `context_audio_filepath` (or `context_audio_codes_path`). - -If we have both `audio_filepath` and `target_audio_codes_path` in the manifest, the dataloader will load from `target_audio_codes_path`. To disable this and extract codes on the fly set the parameter `model.load_cached_codes_if_available=false` during training. Same goes for context audio. - -### Manifests and Datasets - -Manifests can be found in: `/lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/manifests` on draco-oci (`draco-oci-dc-02.draco-oci-iad.nvidia.com`) -I use the following for training. - -``` -Train: -hifitts__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json -rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json -rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_textContextsimplet5_withContextAudioPaths.json -libri100__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json -libri360__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json -mls17k__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_verified_simplet5_withContextAudioPaths.json - -Val: -dev_clean_withContextAudioPaths.json -``` - -Audio File Directories: -``` -HifiTTS: /lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/hi_fi_tts_v0 -Libri100, Libri360 Libri dev: /lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/LibriTTS -Lindy/Rodney: /lustre/fsw/portfolios/llmservice/users/pneekhara/gitrepos/TTS/riva -MLS Audio: /lustre/fsw/portfolios/edgeai/projects/edgeai_riva_rivamlops/data/tts/datasets/mls17k/filtered_24khz/audio_24khz -``` - -Pre-extracted Audio Codes (21 FPS with WavLM) -``` -/lustre/fs11/portfolios/edgeai/projects/edgeai_riva_rivamlops/data/tts/datasets/codecs -``` - -### Command -``` -python examples/tts/t5tts.py \ ---config-name=t5tts \ -max_epochs=1000 \ -weighted_sampling_steps_per_epoch=1000 \ -exp_manager.exp_dir="/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/" \ -+train_ds_meta.rivatrain.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json" \ -+train_ds_meta.rivatrain.audio_dir="/datap/misc/Datasets/riva" \ -+train_ds_meta.rivatrain.feature_dir="/datap/misc/Datasets/riva" \ -+train_ds_meta.rivatrain.sample_weight=1.0 \ -+train_ds_meta.libri360train.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json" \ -+train_ds_meta.libri360train.audio_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ -+train_ds_meta.libri360train.feature_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ -+train_ds_meta.libri360train.sample_weight=1.0 \ -+train_ds_meta.libri100train.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/libri100__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5.json" \ -+train_ds_meta.libri100train.audio_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ -+train_ds_meta.libri100train.feature_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ -+train_ds_meta.libri100train.sample_weight=1.0 \ -+val_ds_meta.librival.manifest_path="/home/pneekhara/2023/SimpleT5NeMo/manifests/dev_clean_withcontext.json" \ -+val_ds_meta.librival.audio_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ 
-+val_ds_meta.librival.feature_dir="/datap/misc/LibriTTSfromNemo/LibriTTS" \ -model.model_type="single_encoder_sv_tts" \ -model.use_text_conditioning_encoder=true \ -model.codecmodel_path="/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ -model.alignment_loss_scale=0.005 \ -model.prior_scaling_factor=0.5 \ -model.prior_scaledown_start_step=5000 \ -model.prior_end_step=8000 \ -model.context_duration_min=3.0 \ -model.context_duration_max=8.0 \ -model.train_ds.dataloader_params.num_workers=2 \ -model.validation_ds.dataloader_params.num_workers=2 \ -trainer.val_check_interval=500 \ -trainer.devices=-1 \ -~model.optim.sched ; -``` - -Audio filepaths in the manifests should be relative to `audio_dir`. Codec paths are absolute. - -Set `model.model_type=multi_encoder_context_tts` for Multi Encoder T5TTS or `decoder_context_tts` for decoder context and `model.use_text_conditioning_encoder=true` if you want both audio/text contexts. - -### Command Lhotse dataset -``` -python examples/tts/t5tts.py \ - --config-name=t5tts_lhotse.yaml \ - batch_size=32 \ - micro_batch_size=32 \ - max_steps=1000000 \ - limit_val_batches=20 \ - trainer.max_steps=1000000 \ - trainer.val_check_interval=500 \ - exp_manager.exp_dir="/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/" \ - model.codecmodel_path="/home/ecasanova/Projects/Checkpoints/Audio_codec/21Hz-no-eliz/AudioCodec_21Hz_no_eliz.nemo" \ - model.alignment_loss_scale=0.01 \ - model.prior_scaling_factor=0.5 \ - model.prior_scaledown_start_step=5000 \ - model.prior_end_step=8000 \ - model.t5_encoder.use_flash_self_attention=true \ - model.t5_encoder.use_flash_x_attention=true \ - model.t5_decoder.use_flash_self_attention=true \ - model.t5_decoder.use_flash_x_attention=false \ - trainer.devices=1 \ - ++model.load_cached_codes_if_available=False \ - ++model.num_audio_codebooks=8 \ - ++model.num_audio_tokens_per_codebook=2048 \ - ++model.codec_model_downsample_factor=1024 \ - ~model.optim.sched ; - -HYDRA_FULL_ERROR=1 PYTHONFAULTHANDLER=1 python examples/tts/t5tts.py \ - --config-name=t5tts_lhotse.yaml \ - exp_manager.exp_dir="/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/" \ - +exp_manager.version=0 \ - eval_batch_size=64 \ - batch_size=384 \ - micro_batch_size=24 \ - max_steps=5000000 \ - batch_duration=350 \ - limit_val_batches=25 \ - trainer.max_steps=5000000 \ - model.codecmodel_path="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ - ++model.train_ds.dataset.input_cfg.0.type="lhotse_shar" \ - ++model.train_ds.dataset.input_cfg.0.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/hifitts_v0/" \ - ++model.train_ds.dataset.input_cfg.0.weight=1.0 \ - ++model.train_ds.dataset.input_cfg.0.tags.lang="en" \ - ++model.train_ds.dataset.input_cfg.0.tags.s2s=True \ - ++model.train_ds.dataset.input_cfg.0.tags.tokenizer_names=["english_phoneme"] \ - ++model.train_ds.dataset.input_cfg.1.type="lhotse_shar" \ - ++model.train_ds.dataset.input_cfg.1.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/libri100/" \ - ++model.train_ds.dataset.input_cfg.1.weight=1.0 \ - ++model.train_ds.dataset.input_cfg.1.tags.lang="en" \ - ++model.train_ds.dataset.input_cfg.1.tags.s2s=True \ - ++model.train_ds.dataset.input_cfg.1.tags.tokenizer_names=["english_phoneme"] \ - ++model.train_ds.dataset.input_cfg.2.type="lhotse_shar" \ - ++model.train_ds.dataset.input_cfg.2.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/rivaLindyRodney/" \ 
- ++model.train_ds.dataset.input_cfg.2.weight=1.0 \ - ++model.train_ds.dataset.input_cfg.2.tags.lang="en" \ - ++model.train_ds.dataset.input_cfg.2.tags.s2s=True \ - ++model.train_ds.dataset.input_cfg.2.tags.tokenizer_names=["english_phoneme"] \ - ++model.train_ds.dataset.input_cfg.3.type="lhotse_shar" \ - ++model.train_ds.dataset.input_cfg.3.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/libri360/" \ - ++model.train_ds.dataset.input_cfg.3.weight=1.0 \ - ++model.train_ds.dataset.input_cfg.3.tags.lang="en" \ - ++model.train_ds.dataset.input_cfg.3.tags.s2s=True \ - ++model.train_ds.dataset.input_cfg.3.tags.tokenizer_names=["english_phoneme"] \ - ++model.validation_ds.dataset.input_cfg.0.type="lhotse_shar" \ - ++model.validation_ds.dataset.input_cfg.0.shar_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/tts_lhotse_datasets/LibriTTS_dev_clean/" \ - ++model.validation_ds.dataset.input_cfg.0.weight=1.0 \ - ++model.validation_ds.dataset.input_cfg.0.tags.lang="en" \ - ++model.validation_ds.dataset.input_cfg.0.tags.s2s=True \ - ++model.validation_ds.dataset.input_cfg.0.tags.tokenizer_names=["english_phoneme"] \ - model.alignment_loss_scale=0.01 \ - model.prior_scaling_factor=0.5 \ - model.prior_scaledown_start_step=5000 \ - model.prior_end_step=8000 \ - model.t5_encoder.use_flash_self_attention=true \ - model.t5_encoder.use_flash_x_attention=true \ - model.t5_decoder.use_flash_self_attention=true \ - model.t5_decoder.use_flash_x_attention=false \ - trainer.val_check_interval=50 \ - trainer.devices=8 \ - ++model.load_cached_codes_if_available=False \ - ++model.num_audio_codebooks=8 \ - ++model.num_audio_tokens_per_codebook=2048 \ - ++model.codec_model_downsample_factor=1024 \ - model.optim.lr=2e-4 \ - trainer.num_nodes=${SLURM_JOB_NUM_NODES} - -``` -Set `model.model_type=multi_encoder_context_tts` for Multi Encoder T5TTS and `model.use_text_conditioning_encoder=true` if you are doing text context training. - -If you change the codec model, make sure to adjust these model config params in `t5tts.yaml`: - -``` -model: - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 4 extra for eos/bos ids - codec_model_downsample_factor: 1024 -``` - -To train then model without CTC loss and prior, set the below params: - -``` -model.alignment_loss_scale=0.0 \ -model.prior_scaling_factor=null \ -``` - -### Training sub files on cluster - -| Model Type | Cluster | Training Sub File | -|------------|---------|--------| -| multi_encoder_context_tts | draco-oci-login-01.draco-oci-iad.nvidia.com |/lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalized_me.sub | -| decoder_context_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalizedt5_decoder.sub | -| single_encoder_sv_tts | draco-oci-login-01.draco-oci-iad.nvidia.com | /lustre/fsw/portfolios/llmservice/users/pneekhara/launchscripts/unnormalizedt5_singleencoder.sub | -| decoder_pretrain_synthesizer | login-eos | /lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/scriptsSimpleT5/newt5_pretrain.sub | - -## Pretrained Models and Results - -Paths to pretrained checkpoints and their evaluation results on some test sets can be found [here](https://docs.google.com/spreadsheets/d/16AkvAHZ-ytWYnzEx9wtOG7yLkuU2wfB8gGMiDa5sROg/edit?usp=sharing) - -## Inference and Eval - -To infer and evaluate from a given checkpoint and hparams.yaml file I use `scripts/t5tts/infer_and_evaluate.py`. 
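Before running the eval script, it can help to sanity-check that the decoder step limit covers the longest eval utterances at the codec frame rate. A rough back-of-the-envelope sketch, assuming a 22.05 kHz codec sample rate (the actual value comes from the codec checkpoint and is not stated here):

```
# Rough check that max_decoder_steps covers the longest utterances.
# sample_rate is an assumption; read it from the codec checkpoint in practice.
sample_rate = 22050
codec_model_downsample_factor = 1024
frames_per_second = sample_rate / codec_model_downsample_factor  # ~21.5, i.e. the "21 FPS" codec

max_utterance_seconds = 20.0  # matches the test_ds max_duration used above
frames_needed = int(max_utterance_seconds * frames_per_second)  # ~430 codec frames
print(frames_per_second, frames_needed)  # max_decoder_steps (e.g. 500) should exceed frames_needed
```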
To evaluate on a given manifest (same structure as discussed above), edit the `dataset_meta_info` in `scripts/t5tts/infer_and_evaluate.py` to point to the paths on your machine, or add any other datasets if missing. - -``` -dataset_meta_info = { - 'vctk': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths.json', - 'audio_dir' : '/datap/misc/Datasets/VCTK-Corpus', - 'feature_dir' : '/datap/misc/Datasets/VCTK-Corpus', - }, - 'riva_challenging': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/challengingLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withContextAudioPaths.json', - 'audio_dir' : '/datap/misc/Datasets/riva', - 'feature_dir' : '/datap/misc/Datasets/riva', - }, - 'libri_val': { - 'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri360_val.json', - 'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - 'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS', - } -} -``` - -Then run - -``` -python scripts/t5tts/infer_and_evaluate.py \ ---hparams_file \ ---checkpoint_file \ ---codecmodel_path /datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo \ ---datasets "vctk,libri_val" \ ---out_dir /datap/misc/Evals \ ---temperature 0.6 \ ---topk 80 -``` - -Ignore the other params in the file; I also use this for evaluating ongoing experiments on the cluster by copying over the checkpoints and hparams. - -### Inference Notebook - -Inference Notebook: `t5tts_inference.ipynb`, for quickly trying custom texts/contexts. - -### Offline Preference Alignment (DPO/RPO) - -Code: `nemo/collections/tts/models/t5tts_preference_optimization.py` - -Preference Alignment (DPO/RPO) involves the following steps: -1) Create a list of text-context pairs for which we will generate preference data. -2) For each text-context pair, generate multiple audios from a base T5-TTS checkpoint and calculate metrics (CER/SSIM) for each generation. -3) Create chosen-rejected pairs from the generated audio. -4) Finetune the base T5-TTS checkpoint on the chosen-rejected pairs. - -#### 1. Create text-context pairs -We pair a list of challenging texts with context audios from the Riva and LibriTTS datasets. We add a similar number of regular texts from LibriTTS and Riva (paired with random context audios). We also include examples with text contexts. There are other options for generating text-context pairs. - -``` -python scripts/t5tts/dpo/create_text_contextpairs.py \ - --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ - --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ - --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ - --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ - --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ - --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ - --nsamples_perpair 6 ; -``` -Each pair is repeated `nsamples_perpair` times, which specifies how many samples we want to generate for each pair. The output manifest serves as the input for the next step. - -We can also explore other options for these text-context pairs depending on the task. - -#### 2. Generate audios for each text-context pair - -Next, we can generate audios from a base T5-TTS checkpoint using the following command.
We pass the `audio_dir` as "/" since our text context pairs contain absolute paths. Model config arguments should be modified accordingly to match the base checkpoint architecture. We can run the below command on the cluster to generate audios across multiple nodes. This command saves the generated audios along with the metrics for each generation in the `exp_dir`. Each generated audio file is accompanied by a `.json` file that has the CER/SSIM metrics. - -Sample sub file on EOS: `/lustre/fsw/llmservice_nemo_speechlm/users/shehzeenh/launchscripts/newdatagendpo_decoder.sub` - -``` -python examples/tts/t5tts.py \ ---config-name=t5tts_inference \ -mode=test \ -batch_size=64 \ -+init_from_ptl_ckpt="/mountdir/checkpoints/continuouscheckpoints_ks1_ks3/decodercontext_small_282.ckpt" \ -exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282" \ -+test_ds_meta.textcontextpairs.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json" \ -+test_ds_meta.textcontextpairs.audio_dir="/" \ -+test_ds_meta.textcontextpairs.feature_dir="/" \ -model.model_type="decoder_context_tts" \ -model.t5_encoder.kernel_size=3 \ -model.t5_decoder.kernel_size=1 \ -model.context_duration_min=5.0 \ -model.context_duration_max=5.0 \ -model.use_text_conditioning_encoder=true \ -model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ -model.alignment_loss_scale=0.002 \ -model.prior_scaling_factor=null \ -model.load_cached_codes_if_available=false \ -trainer.num_nodes=${SLURM_JOB_NUM_NODES} -``` -#### 3. Create chosen-rejected pairs from the generations - -Next, we go through the generated audio directory and create chosen-rejected pairs. - -``` -python scripts/t5tts/dpo/create_preference_pairs.py \ ---input_manifest /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/manifests/dpo_textcontext_pairs.json \ ---generated_audio_dir /lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/audios \ ---group_size 6 \ ---cer_threshold 0.01 \ ---val_size 256 ; -``` - -`cer_threshold=0.01` means that we filter out pairs in which the chosen CER is > 0.01. - -This command should save train and val manifests for DPO finetuning in the base directory of the generated_audio_dir, that is, `/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/`
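The exact pairing logic lives in `create_preference_pairs.py` and is not reproduced here, but conceptually it reduces to something like the sketch below: group the generations for each text-context pair, take the best and worst by CER (with SSIM as a tie-breaker), and drop groups whose best CER exceeds the threshold. The `cer`/`ssim`/`audio_filepath` field names are assumptions for illustration, not the script's actual schema:

```
# Conceptual sketch only -- not the actual create_preference_pairs.py implementation.
# `generations` holds the metric dicts for ONE text-context pair (group_size entries).
def make_preference_pair(generations, cer_threshold=0.01):
    # Lower CER is better; break ties with higher SSIM.
    ranked = sorted(generations, key=lambda g: (g["cer"], -g["ssim"]))
    chosen, rejected = ranked[0], ranked[-1]
    if chosen["cer"] > cer_threshold:
        return None  # pair is filtered out: even the best generation is too inaccurate
    return {"chosen": chosen["audio_filepath"], "rejected": rejected["audio_filepath"]}
```

In practice the real script also writes the train/val manifests and applies its own ranking rules; the sketch is only meant to make the `--group_size` and `--cer_threshold` arguments concrete.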
- -#### 4. DPO Finetuning Command - -Finally, we perform DPO finetuning using the following command: - -``` -python examples/tts/t5tts.py \ -batch_size=4 \ -+init_from_ptl_ckpt="/mountdir/checkpoints/decoder_21_epoch_2.ckpt" \ -+mode="dpo_train" \ -max_epochs=10 \ -exp_manager.exp_dir="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/TrainingsICML/decodercontext_small_282" \ -exp_manager.checkpoint_callback_params.always_save_nemo=false \ -model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ -model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ -+train_ds_meta.dpopreftrain.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_train_manifest.json" \ -+train_ds_meta.dpopreftrain.audio_dir="/" \ -+train_ds_meta.dpopreftrain.feature_dir="/" \ -+val_ds_meta.dpoprefval.manifest_path="/lustre/fsw/llmservice_nemo_speechlm/data/TTS/DPOData/Generations/decodercontext_small_282/T5TTS/version_0/manifests/dpo_val_manifest.json" \ -+val_ds_meta.dpoprefval.audio_dir="/" \ -+val_ds_meta.dpoprefval.feature_dir="/" \ -+model.dpo_beta=0.01 \ -+model.dpo_sft_loss_weight=0.0 \ -model.model_type="decoder_context_tts" \ -model.context_duration_min=5.0 \ -model.context_duration_max=5.0 \ -model.use_text_conditioning_encoder=true \ -model.codecmodel_path="/mountdir/checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ -model.alignment_loss_scale=0.001 \ -model.prior_scaling_factor=null \ -trainer.val_check_interval=200 \ -trainer.log_every_n_steps=10 \ -model.optim.lr=2e-7 \ -~model.optim.sched \ -trainer.num_nodes=${SLURM_JOB_NUM_NODES} -``` - -Note the following overrides in the above command: - -``` -+mode="dpo_train" \ -model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ -model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.T5TTSDatasetDPO" \ -``` - -Again, our manifests contain absolute paths, so we specify `audio_dir="/"`. - -### Online Preference Optimization (GRPO) - -For online preference optimization, the process is much simpler. - -1) Create a list of text-context pairs for which we will generate preference data (just one entry per text-context pair, not repeated). -We'll use the same process as above; just set `--nsamples_perpair 1` in the command. -``` -python scripts/t5tts/dpo/create_text_contextpairs.py \ - --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ - --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ - --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ - --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ - --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ - --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ - --nsamples_perpair 1 ; -``` - -2.
Train using GRPO - -``` -python examples/tts/t5tts.py \ -+mode="onlinepo_train" \ -+init_from_ptl_ckpt="/Data/ICML2025_CKPTS/icml2025_base_checkpoints/decodercontext_small_sp_ks3CorrectWithPrior_onlyphoneme_epoch161.ckpt" \ -max_epochs=1000 \ -exp_manager.exp_dir="/Data/Experiments/NewT5TTSGRPO/Try3NoDropoutBeta0.01_CFG/" \ -+train_ds_meta.grpotrainnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_train_nomls.json" \ -+train_ds_meta.grpotrainnomls.audio_dir="/" \ -+train_ds_meta.grpotrainnomls.feature_dir="/" \ -+val_ds_meta.grpovalnomls.manifest_path="/Data/DPOPairsInputDatav2/text_context_pairs_grpo_val_unseenspeakers_tinysubset.json" \ -+val_ds_meta.grpovalnomls.audio_dir="/" \ -+val_ds_meta.grpovalnomls.feature_dir="/" \ -+model.num_generations_per_item=6 \ -+model.grpo_beta=0.01 \ -model.t5_decoder.p_dropout=0.0 \ -model.t5_encoder.p_dropout=0.0 \ -model.model_type="decoder_context_tts" \ -model.use_text_conditioning_encoder=true \ -model.context_duration_min=5.0 \ -model.context_duration_max=5.0 \ -model.codecmodel_path="/Data/Checkpoints/AudioCodec_21Hz_no_eliz.nemo" \ -model.alignment_loss_scale=0.0 \ -model.prior_scaling_factor=null \ -model.train_ds.dataloader_params.num_workers=0 \ -model.validation_ds.dataloader_params.num_workers=0 \ -exp_manager.checkpoint_callback_params.monitor="val_mean_reward" \ -exp_manager.checkpoint_callback_params.mode="max" \ -+trainer.use_distributed_sampler=False \ -+model.inference_cfg_prob=0.5 \ -+model.inference_cfg_scale=2.5 \ -batch_size=2 \ -model.optim.lr=1e-6 \ -trainer.devices=2 \ -trainer.log_every_n_steps=1 \ -trainer.val_check_interval=50 \ -~model.optim.sched ; -``` \ No newline at end of file diff --git a/scripts/magpietts/README_magpie_po.md b/scripts/magpietts/README_magpie_po.md index 8ba3d2e66655..897287aaf1d5 100644 --- a/scripts/magpietts/README_magpie_po.md +++ b/scripts/magpietts/README_magpie_po.md @@ -32,7 +32,7 @@ Next, we can generate audios from a base TTS checkpoint using the following comm ``` python examples/tts/magpietts.py \ ---config-name=magpietts_inference_en \ +--config-name=magpietts_po_inference \ mode=test \ batch_size=64 \ +init_from_ptl_ckpt= \ @@ -56,7 +56,7 @@ Next, we go through the generated audio directory and create chosen-rejected pai ``` python scripts/magpietts/dpo/create_preference_pairs.py \ --input_manifest \ ---generated_audio_dir /Magpie-TTS-EN-Infer/version_0/audio \ +--generated_audio_dir /MagpieTTS-PO-Infer/version_0/audio \ --group_size 6 \ --cer_threshold 0.01 \ --val_size 256 ; @@ -64,7 +64,7 @@ python scripts/magpietts/dpo/create_preference_pairs.py \ `cer_threshold=0.01` means that filter out pairs in which the chosen CER > 0.01. -This command should save train and val manifests for DPO finetuning in the base directory of the generated_audio_dir, that is, `/Magpie-TTS-EN-Infer/version_0/manifests/` +This command should save train and val manifests for DPO finetuning in the base directory of the generated_audio_dir, that is, `/MagpieTTS-PO-Infer/version_0/manifests/` #### 4. 
DPO Finetuning Command @@ -80,10 +80,10 @@ exp_manager.exp_dir= \ exp_manager.checkpoint_callback_params.always_save_nemo=false \ model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ -+train_ds_meta.dpopreftrain.manifest_path="/Magpie-TTS-EN-Infer/version_0/manifests/" \ ++train_ds_meta.dpopreftrain.manifest_path="/MagpieTTS-PO-Infer/version_0/manifests/" \ +train_ds_meta.dpopreftrain.audio_dir="/" \ +train_ds_meta.dpopreftrain.feature_dir="/" \ -+val_ds_meta.dpoprefval.manifest_path="/Magpie-TTS-EN-Infer/version_0/manifests/dpo_val_manifest.json" \ ++val_ds_meta.dpoprefval.manifest_path="/MagpieTTS-PO-Infer/version_0/manifests/dpo_val_manifest.json" \ +val_ds_meta.dpoprefval.audio_dir="/" \ +val_ds_meta.dpoprefval.feature_dir="/" \ +model.dpo_beta=0.01 \