diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index 5e21b397ca57..a2bb3c2e7da9 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -187,6 +187,16 @@ jobs: - runner: self-hosted-azure script: SPEECHLM_HF_Training_SALM timeout: 20 + - runner: self-hosted-azure + script: L2_TTS_Fast_dev_runs_Magpietts_DecoderContext + - runner: self-hosted-azure + script: L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder + - runner: self-hosted-azure + script: L2_TTS_Fast_dev_runs_Magpietts_OnlinePO + - runner: self-hosted-azure + script: L2_TTS_InferEvaluate_Magpietts_ZeroShot + - runner: self-hosted-azure + script: L2_TTS_InferEvaluate_Magpietts_SeenSpeakers needs: [unit-tests] runs-on: ${{ matrix.runner }} name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }} diff --git a/examples/tts/README_frame_stacking.md b/examples/tts/README_frame_stacking.md new file mode 100644 index 000000000000..3be970c0446a --- /dev/null +++ b/examples/tts/README_frame_stacking.md @@ -0,0 +1,26 @@ +# Overview +This PR introduces a frame-stacking implementation in Magpie-TTS. Frame-stacking is disabled by default. It can be enabled by setting a `frame_stacking_factor` > 1 in the YAML config. + +# Frame-stacking + +## Overview +Frame-stacking is a technique that allows the Magpie-TTS **base decoder** (also known as the "main" or "first stage" decoder) to **process multiple consecutive audio frames in a single forward pass**, leaving the job of generating the individual frames and codebooks to a second, smaller "Local Transformer" ("LT") decoder. The goal is to accelerate inference by reducing the number of generation steps of the base decoder. In this two-stage approach: + +1. The base decoder processes multiple frames at once, producing a single latent representation for each group (stack) of frames. +2. The Local Transformer then generates the individual `frames * codebooks` tokens. + +The Local Transformer is much faster than the base decoder, making this two-stage approach significantly faster than generating each frame with the base decoder. The speed improvement comes from two factors: +* **Fewer parameters**: The LT decoder is lightweight compared to the base decoder +* **Shorter sequences**: The LT decoder only attends to the current frame stack and the latent, not the entire frame sequence + +The base decoder can also generate audio codes directly without an LT, but when frame-stacking is enabled, using the LT decoder is typically necessary to achieve high-quality synthesis. + +## Design and Implementation +* The `frame_stacking_factor` parameter controls the number of frames to stack. The default is 1, which means no frame-stacking. We have tested values up to `4`. +* For each codebook, we keep a separate embedding table for each frame within the stack. At the input to the decoder, the embeddings are averaged across codebooks (as usual) and also across frames within the stack. The embedding tables are shared between the base and LT decoders. A minimal sketch of this embedding scheme is shown after this README. + +## Limitations +This is still a WIP with more work to be done. Specifically, the following are not yet implemented / tested: +* Online code extraction combined with frame-stacking. +* Alignment encoder with frame-stacking. +* CTC loss with frame-stacking.
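The "Design and Implementation" bullet above describes the stacked embedding lookup only in prose, so here is a minimal, self-contained sketch of that scheme. It is illustrative only: the class name `StackedCodeEmbedding` and its constructor arguments are hypothetical and are not the actual Magpie-TTS module or parameter names.

```python
# Hypothetical sketch of the frame-stacked embedding lookup described in the README:
# one embedding table per (codebook, frame-within-stack) pair, averaged into a single
# vector per stack at the decoder input. Not the actual Magpie-TTS implementation.
import torch
import torch.nn as nn


class StackedCodeEmbedding(nn.Module):
    def __init__(self, num_codebooks: int, frame_stacking_factor: int, codebook_size: int, d_model: int):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.frame_stacking_factor = frame_stacking_factor
        # These tables would be shared between the base decoder and the Local Transformer.
        self.tables = nn.ModuleList(
            nn.Embedding(codebook_size, d_model) for _ in range(num_codebooks * frame_stacking_factor)
        )

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: (batch, num_codebooks, num_frames); num_frames must be divisible by the stacking factor.
        b, c, t = codes.shape
        s = self.frame_stacking_factor
        # Group consecutive frames into stacks: (batch, codebooks, num_stacks, frames_per_stack).
        codes = codes.view(b, c, t // s, s)
        emb = torch.zeros(b, t // s, self.tables[0].embedding_dim)
        for cb in range(c):
            for f in range(s):
                emb = emb + self.tables[cb * s + f](codes[:, cb, :, f])
        # Average across codebooks and across frames within the stack -> one vector per stack.
        return emb / (c * s)


if __name__ == "__main__":
    layer = StackedCodeEmbedding(num_codebooks=8, frame_stacking_factor=2, codebook_size=2024, d_model=768)
    fake_codes = torch.randint(0, 2024, (3, 8, 10))  # 10 frames -> 5 stacks
    print(layer(fake_codes).shape)  # torch.Size([3, 5, 768])
```

With 8 codebooks and a stacking factor of 2, 10 input frames collapse into 5 stack-level embeddings, which is exactly the reduction in base-decoder generation steps that motivates the feature.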
\ No newline at end of file diff --git a/examples/tts/README_magpietts_legacy_checkpoints.md b/examples/tts/README_magpietts_legacy_checkpoints.md new file mode 100644 index 000000000000..a6a048ea39ab --- /dev/null +++ b/examples/tts/README_magpietts_legacy_checkpoints.md @@ -0,0 +1,95 @@ +# Background +Magpie-TTS uses special tokens like AUDIO_BOS and AUDIO_EOS for its operation. The indices of these tokens come after the audio codec tokens, at the end of the embedding table. + +In April 2025 we changed the layout of the embedding table in a non-backwards-compatible way: + +## Old Layout (until April 16) +With the most common codec configuration (2016 codes), the layout used to look like this: +``` +| Index | Token Description | Comments | +|---------|----------------------|-----------------------------------------------------------------------------------------------------------| +| [0] | Codec Token 0 | | +| [1] | Codec Token 1 | | +| [2] | Codec Token 2 | | +| ... | ... | | +| [2015] | Codec Token 2015 | | +| [2016] | | | +| [2017] | | | +| [2018] | | | +| ... | | | +| [2044] | Context Audio BOS | if model_type == `decoder_context_tts` | +| [2045] | Context Audio EOS | if model_type == `decoder_context_tts` | +| [2046] | Audio BOS | also used for Context Audio BOS if model_type == `multi_encoder_context_tts` or `single_encoder_sv_tts` | +| [2047] | Audio EOS | also used for Context Audio EOS if model_type == `multi_encoder_context_tts` or `single_encoder_sv_tts` | +``` + +## New Layout +The new layout for the same codec configuration is: +``` +| Index | Token Description | Comments | +|---------|----------------------|----------| +| [0] | Codec Token 0 | | +| [1] | Codec Token 1 | | +| [2] | Codec Token 2 | | +| ... | ... | | +| [2015] | Codec Token 2015 | | +| [2016] | Audio BOS | | +| [2017] | Audio EOS | | +| [2018] | Context Audio BOS | | +| [2019] | Context Audio EOS | | +| [2020] | MASK token (MaskGit) | | +| [2021] | RESERVED_1 | | +| [2022] | RESERVED_2 | | +| [2023] | RESERVED_3 | | +``` + +# How to Train and Load a New Checkpoint +For new trainings and inference, all configuration is automatic: +* The number of codebooks, codec codebook size, and codec downsampling rate are all read from the codec checkpoint rather than configured in Magpie. +* The embedding table size is automatically set to codec_codebook_size + number_of_special_tokens (currently 2016+8=2024). There is no risk of accidentally stepping on codec tokens since the table is automatically sized with enough room for the special tokens. + +# How to Load Old Checkpoints +For checkpoints created before the change, you can force the legacy codebook layout in one of these ways: + +## If using `infer_and_evaluate.py` +Just set the `--legacy_codebooks` command line option. There is no need to update your YAML file; the script will automatically add the overrides. + +## If using a Hydra command line +This scenario arises when either finetuning with an old checkpoint or doing data generation with an old checkpoint.
+ +You have two options: +### Add these to your command line +``` +# decoder context model ++model.forced_num_all_tokens_per_codebook=2048 +model.forced_audio_eos_id=2047 +model.forced_audio_bos_id=2046 +model.forced_context_audio_eos_id=2045 +model.forced_context_audio_bos_id=2044 + +# multi encoder context and any other model type ++model.forced_num_all_tokens_per_codebook=2048 +model.forced_audio_eos_id=2047 +model.forced_audio_bos_id=2046 +model.forced_context_audio_eos_id=2047 +model.forced_context_audio_bos_id=2046 +``` +### Or, add these overrides to your YAML file +``` +forced_num_all_tokens_per_codebook: 2048 +forced_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -1} # 2047 +forced_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} # 2046 + +# Depending on the old model type, the context_audio_bos_id and context_audio_eos_id will be different (choose one of the pairs below) + +# For `multi_encoder_context_tts`, `single_encoder_sv_tts`: +#forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -1} # 2047 +#forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} # 2046 + +# For `decoder_context_tts` models: +#forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -3} # 2045 +#forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -4} # 2044 +``` + +# Additional Details +Over the last few weeks we have gone through a few embedding table layouts. When using an old checkpoint, it is important to know which layout your checkpoint was trained with and to configure the system accordingly. + +* Layout 1: used until April 16 (described in the table above). Add `--legacy-codebooks` to the `infer_and_evaluate.py` command line to run inference with this layout. + +* Layout 2: after the [config changes](https://github.com/blisc/NeMo/commit/7e2cdca74a866ecefdbe01c0076ad9b5d140ac61): 2018 tokens with special tokens at the end (2017, 2016, 2015, 2014, the last two overwriting codec tokens). This is an invalid layout and these checkpoints should not be used. + +* Layout 3: after the [bugfix](https://github.com/blisc/NeMo/commit/23e299a0bd14b666543b4bbcc7783f783acb0bd3) but before the [refactoring](https://github.com/blisc/NeMo/commit/8ba55061a0ebb161abff4b329e402d5307f4af98): 2024 tokens with special tokens at the end (2023, 2022, 2021, 2020). There are no automatic options provided for using this layout, but it can be manually configured by updating the `hparams.yaml` file with the `forced_*` options. Set `forced_num_all_tokens_per_codebook` to `2024` and set the rest of the overrides as defined under section `### Or, add these overrides to your YAML file` above. + +* Layout 4: the new layout, [from this commit onwards](https://github.com/blisc/NeMo/commit/8ba55061a0ebb161abff4b329e402d5307f4af98): 2024 tokens but with special tokens immediately after the codec tokens (2016, 2017, 2018, 2019). Training and inference with the latest version of the code automatically use this layout. diff --git a/examples/tts/conf/magpietts/magpietts.yaml b/examples/tts/conf/magpietts/magpietts.yaml new file mode 100644 index 000000000000..4c45f38fb4b3 --- /dev/null +++ b/examples/tts/conf/magpietts/magpietts.yaml @@ -0,0 +1,199 @@ +name: Magpie-TTS + +max_epochs: ??? +# Adjust batch size based on GPU memory +batch_size: 16 +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled.
+weighted_sampling_steps_per_epoch: null + +# Dataset metadata for each manifest +# See DatasetMeta in https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/tts/data/text_to_speech_dataset.py +train_ds_meta: ??? +val_ds_meta: ??? + +model: + model_type: "decoder_ce" # decoder_context_tts or decoder_ce + use_text_conditioning_encoder: true # Enable or disable text context. Audio context is always enabled + text_conditioning_tokenizer_name: text_ce_tokenizer # The tokenizer to be used for text contexts + context_duration_min: 5.0 + context_duration_max: 5.0 + load_cached_codes_if_available: true + prior_scaling_factor: 0.5 + prior_end_step: 12000 + prior_scaledown_start_step: 8000 + indefinite_prior_prob: 0.0 # If > 0, then prior will be applied after prior_end_step with this probability. + alignment_loss_scale: 0.002 + embedding_dim: 768 + codecmodel_path: ??? + max_epochs: ${max_epochs} + steps_per_epoch: ${weighted_sampling_steps_per_epoch} + cfg_unconditional_prob: 0.1 + + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
+ + text_tokenizers: + english_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + text_ce_tokenizer: # Used for text context + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + ### For additional languages, consider adding a generic byt5 tokenizer like the one below + # french_chartokenizer: # Used for text context + # _target_: AutoTokenizer + # pretrained_model: "google/byt5-small" + + train_ds: + dataset: + _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset + dataset_meta: ${train_ds_meta} + weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} + min_duration: 0.2 + max_duration: 20.0 + + dataloader_params: + batch_size: ${batch_size} + num_workers: 4 + drop_last: true + pin_memory: true + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset + dataset_meta: ${val_ds_meta} + min_duration: 0.2 + max_duration: 20.0 + + dataloader_params: + batch_size: ${batch_size} + num_workers: 4 + pin_memory: true + + encoder: + n_layers: 6 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + context_encoder: # Only used for decoder_ce (and multi_encoder_context_tts), ignored otherwise + n_layers: 1 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: false + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + decoder: + n_layers: 12 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 1 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: true + xa_d_head: 128 + xa_d_memory: 768 + xa_n_heads: 1 + is_causal: true + apply_norm_to_cond: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + make_prior_window_strict: true + + optim: + _target_: torch.optim.AdamW + lr: 2e-4 + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: 32 + max_epochs: ${max_epochs} + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 1 + num_sanity_val_steps: 0 + benchmark: false + gradient_clip_val: 2.5 + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + entity: null + project: null + group: null + name: ${name} + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. 
+ create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' + resume_if_exists: true + resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_en.yaml b/examples/tts/conf/magpietts/magpietts_en.yaml deleted file mode 100644 index e71d9ac5d261..000000000000 --- a/examples/tts/conf/magpietts/magpietts_en.yaml +++ /dev/null @@ -1,169 +0,0 @@ -name: Magpie-TTS-EN - -max_epochs: ??? -# Adjust batch size based on GPU memory -batch_size: 16 -# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. -# If null, then weighted sampling is disabled. -weighted_sampling_steps_per_epoch: null - -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -train_ds_meta: ??? -val_ds_meta: ??? - -# Modify these values based on your sample rate -sample_rate: 22050 - -model: - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids - codec_model_downsample_factor: 1024 - load_cached_codes_if_available: true - prior_scaling_factor: 0.5 - prior_end_step: 12000 - prior_scaledown_start_step: 8000 - alignment_loss_scale: 0.0 - embedding_dim: 768 - codecmodel_path: ??? 
- max_epochs: ${max_epochs} - steps_per_epoch: ${weighted_sampling_steps_per_epoch} - - sample_rate: ${sample_rate} - - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - - train_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${train_ds_meta} - weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} - min_duration: 0.5 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 4 - drop_last: true - - validation_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${val_ds_meta} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} - min_duration: 0.5 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 0 - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: False - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 768 - xa_n_heads: 12 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: 32 - max_epochs: ${max_epochs} - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - val_check_interval: 500 - # check_val_every_n_epoch: 10 - benchmark: false - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml deleted file mode 100644 index 9831131092ed..000000000000 --- a/examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml +++ /dev/null @@ -1,210 +0,0 @@ -name: Magpie-TTS-ML-V1-Infer -mode: test -init_from_ptl_ckpt: ??? 
-max_epochs: 1 -# Adjust batch size based on GPU memory -batch_size: 16 -# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. -# If null, then weighted sampling is disabled. -weighted_sampling_steps_per_epoch: null - -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -test_ds_meta: ??? - -# Modify these values based on your sample rate -sample_rate: 22050 - -phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" -heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" -model: - use_kv_cache_for_inference: true - inference_temperature: 0.7 - inference_topk: 80 - inference_use_cfg: false - inference_cfg_scale: 1.0 - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 - max_decoder_steps: 500 - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids - codec_model_downsample_factor: 1024 - load_cached_codes_if_available: true - prior_scaling_factor: null - prior_end_step: 0 - prior_scaledown_start_step: 0 - alignment_loss_scale: 0.0 - embedding_dim: 768 - codecmodel_path: null - max_epochs: ${max_epochs} - steps_per_epoch: ${weighted_sampling_steps_per_epoch} - - sample_rate: ${sample_rate} - - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - spanish_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: es-ES - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: es-ES - phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - german_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: de-DE - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: 'de-DE' - phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" - heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - grapheme_case: mixed - grapheme_prefix: '#' - mandarin_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: 
nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p - phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" - word_segmenter: "jieba" - phoneme_prefix: "" - phoneme_case: "lower" - tone_prefix: "#" - ascii_letter_prefix: "" - ascii_letter_case: "upper" - multilingual_sentencepiece: - _target_: AutoTokenizer - pretrained_model: "bert-base-multilingual-uncased" - - test_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${test_ds_meta} - sample_rate: ${sample_rate} - min_duration: 0.5 - max_duration: 20.0 - # speaker_path: ${speaker_path} - - dataloader_params: - batch_size: ${batch_size} - num_workers: 2 - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: False - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 768 - xa_n_heads: 12 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: 32 - max_epochs: ${max_epochs} - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - val_check_interval: 500 - # check_val_every_n_epoch: 10 - benchmark: false - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_lhotse.yaml new file mode 100644 index 000000000000..241fddfff232 --- /dev/null +++ b/examples/tts/conf/magpietts/magpietts_lhotse.yaml @@ -0,0 +1,217 @@ +name: Magpie-TTS + +quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. +model: + use_lhotse: true + model_type: "decoder_ce" # decoder_context_tts or decoder_ce + use_text_conditioning_encoder: true # Enable or disable text context. Audio context is always enabled + text_conditioning_tokenizer_name: text_ce_tokenizer # The tokenizer to be used for text contexts + context_duration_min: 5.0 + context_duration_max: 5.0 + load_cached_codes_if_available: true + prior_scaling_factor: 0.5 + prior_end_step: 12000 + prior_scaledown_start_step: 8000 + indefinite_prior_prob: 0. # If > 0, then prior will be applied after prior_end_step with this probability. + alignment_loss_scale: 0.002 + embedding_dim: 768 + codecmodel_path: ??? 
+ cfg_unconditional_prob: 0.1 + + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 + + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. + + text_tokenizers: + english_phoneme: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + text_ce_tokenizer: # Used for text context + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + ### For additional languages, consider adding a generic byt5 tokenizer like the one below + # french_chartokenizer: # Used for text context + # _target_: AutoTokenizer + # pretrained_model: "google/byt5-small" + + train_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration : ??? # in seconds. Adjust based on your GPU memory. + quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 20_000 + shuffle_buffer_size: 20_000 + num_cuts_for_bins_estimate: 20_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 6 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + + validation_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. 
+ quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + drop_last: false + shuffle: false + num_workers: 2 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + encoder: + n_layers: 6 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + context_encoder: # Only used for decoder_ce (and multi_encoder_context_tts), ignored otherwise + n_layers: 1 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 3 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: false + is_causal: false + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + + decoder: + n_layers: 12 + d_model: 768 + d_ffn: 3072 + sa_n_heads: 12 + kernel_size: 1 + p_dropout: 0.1 + p_dropout_out: 0.0 + has_xattn: true + xa_d_head: 128 + xa_d_memory: 768 + xa_n_heads: 1 + is_causal: true + apply_norm_to_cond: true + apply_norm_out: true + max_length_causal_mask: 2048 + use_learnable_pos_emb: true + make_prior_window_strict: true + + optim: + _target_: torch.optim.AdamW + lr: 2e-4 + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: 32 + max_steps: ??? + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 1 + limit_train_batches: 1_000 + val_check_interval: 1_000 + num_sanity_val_steps: 0 + benchmark: false + use_distributed_sampler: false # required because Lhotse has its own handling + gradient_clip_val: 2.5 + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + entity: null + project: null + group: null + name: ${name} + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' + resume_if_exists: true + resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml b/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml deleted file mode 100644 index 82d9ec6f3890..000000000000 --- a/examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml +++ /dev/null @@ -1,217 +0,0 @@ -name: Magpie-TTS-ML-V1 - -max_epochs: ??? -# Adjust batch size based on GPU memory -batch_size: 16 -# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. -# If null, then weighted sampling is disabled. -weighted_sampling_steps_per_epoch: null - -# Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 -train_ds_meta: ??? -val_ds_meta: ??? - -# Modify these values based on your sample rate -sample_rate: 22050 - -model: - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false # If true, distilbert will be used to encode context_text if provided. 
- transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 # Only used for single_encoder_sv_tts, ignored otherwise - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids - codec_model_downsample_factor: 1024 - load_cached_codes_if_available: true - prior_scaling_factor: 0.5 - prior_end_step: 12000 - prior_scaledown_start_step: 8000 - alignment_loss_scale: 0.0 - embedding_dim: 768 - codecmodel_path: ??? - max_epochs: ${max_epochs} - steps_per_epoch: ${weighted_sampling_steps_per_epoch} - - sample_rate: ${sample_rate} - - text_tokenizers: # Add more languages for multi-lingual TTS - english_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - punct: true - apostrophe: true - pad_with_space: false - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" - heteronyms: "scripts/tts_dataset_files/heteronyms-052722" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - spanish_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: es-ES - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: es-ES - phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - german_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer - locale: de-DE - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p - locale: 'de-DE' - phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" - heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" - phoneme_probability: 0.8 - ignore_ambiguous_words: false - use_chars: true - use_stresses: true - grapheme_case: mixed - grapheme_prefix: '#' - mandarin_phoneme: - _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer - punct: true - apostrophe: true - pad_with_space: true - g2p: - _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p - phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" - word_segmenter: "jieba" - phoneme_prefix: "" - phoneme_case: "lower" - tone_prefix: "#" - ascii_letter_prefix: "" - ascii_letter_case: "upper" - multilingual_sentencepiece: - _target_: AutoTokenizer - pretrained_model: "bert-base-multilingual-uncased" - - train_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${train_ds_meta} - weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} - # speaker_path: ${speaker_path} - min_duration: 0.5 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 4 - drop_last: true - - validation_ds: - dataset: - _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset - dataset_meta: ${val_ds_meta} - sample_rate: ${sample_rate} - # speaker_path: 
${speaker_path} - min_duration: 0.5 - max_duration: 20.0 - - dataloader_params: - batch_size: ${batch_size} - num_workers: 0 - - encoder: - n_layers: 6 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: False - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: false - is_causal: false - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - decoder: - n_layers: 12 - d_model: 768 - d_ffn: 3072 - sa_n_heads: 12 - kernel_size: 3 - p_dropout: 0.1 - p_dropout_out: 0.0 - has_xattn: true - xa_d_memory: 768 - xa_n_heads: 12 - is_causal: true - apply_norm_to_cond: true - apply_norm_out: true - max_length_causal_mask: 2048 - use_learnable_pos_emb: true - - optim: - _target_: torch.optim.Adam - lr: 2e-4 - betas: [0.8, 0.99] - - sched: - name: ExponentialLR - gamma: 0.998 - -trainer: - num_nodes: 1 - devices: -1 - accelerator: gpu - strategy: ddp_find_unused_parameters_true - precision: 32 - max_epochs: ${max_epochs} - accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - log_every_n_steps: 100 - val_check_interval: 500 - # check_val_every_n_epoch: 10 - benchmark: false - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null - create_checkpoint_callback: true - checkpoint_callback_params: - monitor: val_loss - mode: min - save_top_k: 5 - save_best_model: true - always_save_nemo: true - resume_if_exists: true - resume_ignore_no_checkpoint: true diff --git a/examples/tts/conf/magpietts/magpietts_inference_en.yaml b/examples/tts/conf/magpietts/magpietts_po_inference.yaml similarity index 51% rename from examples/tts/conf/magpietts/magpietts_inference_en.yaml rename to examples/tts/conf/magpietts/magpietts_po_inference.yaml index eec62db5547d..735e750a899e 100644 --- a/examples/tts/conf/magpietts/magpietts_inference_en.yaml +++ b/examples/tts/conf/magpietts/magpietts_po_inference.yaml @@ -1,4 +1,4 @@ -name: Magpie-TTS-EN-Infer +name: MagpieTTS-PO-Infer mode: test init_from_ptl_ckpt: ??? max_epochs: 1 @@ -9,44 +9,63 @@ batch_size: 16 weighted_sampling_steps_per_epoch: null # Dataset metadata for each manifest -# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 +# See DatasetMeta in https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/tts/data/text_to_speech_dataset.py test_ds_meta: ??? -# Modify these values based on your sample rate -sample_rate: 22050 - phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" model: + # Inference hyperparameters use_kv_cache_for_inference: true inference_temperature: 0.7 inference_topk: 80 inference_use_cfg: false inference_cfg_scale: 1.0 - model_type: "multi_encoder_context_tts" # single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts or decoder_pretrain_synthesizer - use_text_conditioning_encoder: false - transcript_decoder_layers: [3,4,5,6,7] # ONLY used in multi_encoder_context_tts, Transcript goes to these layer ids, context goes to the rest. 
In single_encoder_sv_tts, all layers are used for transcript. - context_decoder_layers: [8,9] # ONLY used in multi_encoder_context_tts - context_duration_min: 3.0 - context_duration_max: 8.0 - speaker_emb_dim: 192 max_decoder_steps: 500 - num_audio_codebooks: 8 - num_audio_tokens_per_codebook: 2048 # Keep atleast 2 extra for eos/bos ids - codec_model_downsample_factor: 1024 + + model_type: "decoder_ce" # decoder_context_tts or decoder_ce + use_text_conditioning_encoder: true # Enable or disable text context. Audio context is always enabled + text_conditioning_tokenizer_name: text_ce_tokenizer # The tokenizer to be used for text contexts + context_duration_min: 5.0 + context_duration_max: 5.0 load_cached_codes_if_available: true prior_scaling_factor: null prior_end_step: 0 prior_scaledown_start_step: 0 alignment_loss_scale: 0.0 embedding_dim: 768 - codecmodel_path: null + codecmodel_path: ??? max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample_rate} + # Alignment encoder parameters, to binarize the prior + # This is used for attention-constrained training and inference + use_alignment_encoder: false + # Below args are only relevant if use_alignment_encoder is true + use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder + alignment_encoder_loss_scale: 1.0 + binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step. + binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax. + prior_future_context: 2 # Future window of the binarized prior. + prior_past_context: 2 # Past window of the binarized prior. + prior_future_decay: 0.8 # Decay factor for future context + prior_past_decay: 0.5 # Decay factor for past context + binarize_repeat_audio_factor: 2 # Temporally increase audio timesteps, for nemo_binarize to work better. Increase this for low frame rate codecs + binarized_prior_epsilon: 0.0 + aligner_encoder_train_steps: 50000 + + # Local transformer parameters for autoregressive codebook prediction within a frame + local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit" + # Below args are only relevant if use_local_transformer is true + local_transformer_loss_scale: 1.0 + local_transformer_n_layers: 1 + local_transformer_n_heads: 1 + local_transformer_hidden_dim: 256 - text_tokenizers: # Add more languages for multi-lingual TTS + text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts. + text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context. 
+ + text_tokenizers: english_phoneme: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer punct: true @@ -60,15 +79,20 @@ model: ignore_ambiguous_words: false use_chars: true use_stresses: true + text_ce_tokenizer: # Used for text context + _target_: AutoTokenizer + pretrained_model: "google/byt5-small" + ### For additional languages, consider adding a generic byt5 tokenizer like the one below + # french_chartokenizer: # Used for text context + # _target_: AutoTokenizer + # pretrained_model: "google/byt5-small" test_ds: dataset: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${test_ds_meta} - sample_rate: ${sample_rate} - min_duration: 0.5 + min_duration: 0.2 max_duration: 20.0 - # speaker_path: ${speaker_path} dataloader_params: batch_size: ${batch_size} @@ -83,13 +107,13 @@ model: p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: false - is_causal: False + is_causal: true apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true - context_encoder: # Only used for multi_encoder_context_tts, ignored otherwise - n_layers: 3 + context_encoder: # Only used for decoder_ce (and multi_encoder_context_tts), ignored otherwise + n_layers: 1 d_model: 768 d_ffn: 3072 sa_n_heads: 12 @@ -107,22 +131,23 @@ model: d_model: 768 d_ffn: 3072 sa_n_heads: 12 - kernel_size: 3 + kernel_size: 1 p_dropout: 0.1 p_dropout_out: 0.0 has_xattn: true + xa_d_head: 128 xa_d_memory: 768 - xa_n_heads: 12 + xa_n_heads: 1 is_causal: true apply_norm_to_cond: true apply_norm_out: true max_length_causal_mask: 2048 use_learnable_pos_emb: true + make_prior_window_strict: true optim: - _target_: torch.optim.Adam + _target_: torch.optim.AdamW lr: 2e-4 - betas: [0.8, 0.99] sched: name: ExponentialLR @@ -140,8 +165,8 @@ trainer: logger: false # Provided by exp_manager log_every_n_steps: 100 val_check_interval: 500 - # check_val_every_n_epoch: 10 benchmark: false + gradient_clip_val: 2.5 exp_manager: exp_dir: null @@ -149,8 +174,11 @@ exp_manager: create_tensorboard_logger: true create_wandb_logger: false wandb_logger_kwargs: - name: null + entity: null project: null + group: null + name: ${name} + resume: true # enable resume to ensure continuous training log metrics merged on the previous run id. create_checkpoint_callback: true checkpoint_callback_params: monitor: val_loss @@ -158,5 +186,6 @@ exp_manager: save_top_k: 5 save_best_model: true always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}' resume_if_exists: true - resume_ignore_no_checkpoint: true + resume_ignore_no_checkpoint: true \ No newline at end of file diff --git a/examples/tts/magpietts.py b/examples/tts/magpietts.py index af6a6f9a1752..4676afad7626 100644 --- a/examples/tts/magpietts.py +++ b/examples/tts/magpietts.py @@ -12,46 +12,91 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytorch_lightning as pl +import lightning.pytorch as pl +import torch.multiprocessing as mp from omegaconf import OmegaConf, open_dict -from nemo.collections.tts.models import MagpieTTS_Model, MagpieTTS_ModelDPO, MagpieTTS_ModelInference +from nemo.collections.tts.models import ( + MagpieTTSModel, + MagpieTTSModelOfflinePO, + MagpieTTSModelOfflinePODataGen, + MagpieTTSModelOnlinePO, +) from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager -@hydra_runner(config_path="conf/magpietts", config_name="magpietts_en") +@hydra_runner(config_path="conf/magpietts", config_name="magpietts_lhotse") def main(cfg): logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg, resolve=True)) - if not cfg.model.get('use_lthose', False): - import torch.multiprocessing as mp - mp.set_start_method("spawn", force=True) + # Force the "spawn" start method for multiprocessing over "fork" when using multiple + # worker processes for dataloaders. By default, multiprocessing uses "fork" to create + # worker processes, which inherit the memory state of the main process, including its + # already initialized CUDA state. When a worker process tries to use + # CUDA, it runs into conflicts with the inherited, now potentially invalid, + # CUDA context, resulting in a CUDA initialization error. When + # num_workers=0, all dataloading happens in the main process, so there is no + # process forking and no CUDA context conflict. When num_workers>0, the standard way + # to fix this is to use "spawn" to create a completely new and clean Python process for + # each worker, avoiding the problematic CUDA state inheritance. + mp.set_start_method("spawn", force=True) trainer = pl.Trainer(**cfg.trainer) + trainer.callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step', log_weight_decay=True)) exp_manager(trainer, cfg.get("exp_manager", None)) - if cfg.get('mode', 'train') == 'train': - model = MagpieTTS_Model(cfg=cfg.model, trainer=trainer) - elif cfg.get('mode', 'dpo_train') == 'dpo_train': + seed = cfg.get('seed', None) + if seed is not None: + # Option to seed for debugging + logging.info(f"Setting seed to {seed}") + pl.seed_everything(seed, workers=True) + + mode = cfg.get('mode', 'train') + if mode == 'train': + model = MagpieTTSModel(cfg=cfg.model, trainer=trainer) + elif mode == 'dpo_train': + model_cfg = cfg.model + with open_dict(model_cfg): + model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt + model = MagpieTTSModelOfflinePO(cfg=model_cfg, trainer=trainer) + elif mode == 'onlinepo_train': model_cfg = cfg.model with open_dict(model_cfg): model_cfg.reference_model_ckpt_path = cfg.init_from_ptl_ckpt - model = MagpieTTS_ModelDPO(cfg=model_cfg, trainer=trainer) - elif cfg.get('mode', 'train') == 'test': - model = MagpieTTS_ModelInference(cfg=cfg.model, trainer=trainer) + model = MagpieTTSModelOnlinePO(cfg=model_cfg, trainer=trainer) + elif mode == 'test': + model = MagpieTTSModelOfflinePODataGen(cfg=cfg.model, trainer=trainer) else: - raise NotImplementedError(f"Only train, dpo_train and test modes are supported. Got {cfg.mode}") + raise NotImplementedError(f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}") model.maybe_init_from_pretrained_checkpoint(cfg=cfg) - if cfg.get('mode', 'train') in ['train', 'dpo_train']: - trainer.fit(model) - elif cfg.get('mode', 'train') == 'test': - trainer.test(model) - else: - raise NotImplementedError(f"Only train and test modes are supported.
Got {cfg.mode}") + try: + if mode in ['train', 'dpo_train', 'onlinepo_train']: + logging.info("Starting training...") + trainer.fit(model) + elif mode == 'test': + logging.info("Starting testing...") + trainer.test(model) + else: + raise NotImplementedError( + f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}" + ) + logging.info("Training/testing completed successfully.") + finally: + # Ensure WandB completes all uploads before Python thread shutdown + # Critical when num_workers=0 during debugging - the main process can become + # overwhelmed and fail to properly coordinate with WandB's background threads + try: + import wandb + + if wandb.run is not None: + logging.info("Finishing WandB run to prevent threading shutdown hang...") + wandb.finish() + except Exception as e: + logging.warning(f"Error finishing WandB: {e}") if __name__ == '__main__': diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index f62b15f98dc0..0447fc2b26a6 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -45,6 +45,8 @@ ) from nemo.collections.common.data.lhotse.sampling import ( BucketingFilter, + CERFilter, + ContextSpeakerSimilarityFilter, DurationFilter, FixedBucketBatchSizeConstraint2D, MultimodalFixedBucketBatchSizeConstraint2D, @@ -52,6 +54,7 @@ TokenCountFilter, TokenPerSecondFilter, TokenPerTokenFilter, + ValidationStatusFilter, ) from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn from nemo.collections.common.prompts import PromptFormatter @@ -131,6 +134,13 @@ class LhotseDataLoadingConfig: min_tpt: int = -1 # allowed tokens per token (text-only) max_tpt: Any = float("inf") # float | list[float] + # 2.3 Filters on CER and/or cosine speaker similarity of the context audio serving for TTS use cases. + max_cer: float | None = float("inf") + min_context_speaker_similarity: float | None = -1 + + # 2.4 Filters on validation status. If the validation status is not "pass", the cut will be filtered out. + keep: str = "pass" + # 3. Supported existing NeMo options. shuffle: bool = False sample_rate: int = 16000 @@ -230,7 +240,7 @@ def get_lhotse_dataloader_from_config( tokenizer=None, ) -> torch.utils.data.DataLoader: """ - Set up a Lhotse training dataloder. + Set up a Lhotse training dataloader. Expects a typical NeMo dataset configuration format, with additional fields: "use_lhotse=True". Some fields in the original NeMo configuration may be ignored. @@ -276,7 +286,7 @@ def get_lhotse_dataloader_from_single_config( tokenizer=None, ) -> torch.utils.data.DataLoader: """ - Set up a Lhotse training dataloder. + Set up a Lhotse training dataloader. Expects a typical NeMo dataset configuration format, with additional fields: "use_lhotse=True". Some fields in the original NeMo configuration may be ignored. @@ -549,6 +559,13 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No TokenCountFilter(config.min_tokens, config.max_tokens, measure_total_length=config.measure_total_length) ) + # validation status filtering + cuts = cuts.filter(ValidationStatusFilter(config.keep)) + # CER filtering, same as native NeMo dataloaders. + cuts = cuts.filter(CERFilter(config.max_cer)) + # Context speaker similarity filtering, same as native NeMo dataloaders. 
+ cuts = cuts.filter(ContextSpeakerSimilarityFilter(config.min_context_speaker_similarity)) + if tokenizer is not None and config.pretokenize: cuts = cuts.filter(TokenPerSecondFilter(config.min_tps, config.max_tps)) cuts = cuts.filter(TokenPerTokenFilter(config.min_tpt, config.max_tpt)) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 96ff66f3791d..0c927058c234 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -18,7 +18,7 @@ from typing import Any, Sequence import numpy as np -from lhotse.cut import Cut +from lhotse.cut import Cut, MonoCut from lhotse.dataset import SamplingConstraint, TokenConstraint from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint from lhotse.utils import ifnone @@ -268,6 +268,66 @@ def __call__(self, example) -> bool: return True # does not apply to text etc. +class ValidationStatusFilter: + """ + Callable, returns ``True`` if a cut's validation status is equal to keep and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__(self, keep: str = "pass") -> None: + self.keep = keep + + def __call__(self, example) -> bool: + if ( + isinstance(example, MonoCut) + and example.has_custom("validation_status") + and example.validation_status != self.keep + ): + return False + else: + return True + + +class CERFilter: + """ + Callable, returns ``True`` if a cut's CER is less than max_cer and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__(self, max_cer: float | None) -> None: + self.max_cer = ifnone(max_cer, float("inf")) + + def __call__(self, example) -> bool: + if ( + isinstance(example, MonoCut) + and len(example.supervisions) > 0 + and example.supervisions[0].has_custom("cer") + ): + return example.supervisions[0].cer <= self.max_cer + else: + return True + + +class ContextSpeakerSimilarityFilter: + """ + Callable, returns ``True`` if a cut's context speaker similarity is greater than min_context_speaker_similarity and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__(self, min_context_speaker_similarity: float | None) -> None: + self.min_context_speaker_similarity = ifnone(min_context_speaker_similarity, -1) + + def __call__(self, example) -> bool: + if ( + isinstance(example, MonoCut) + and len(example.supervisions) > 0 + and example.supervisions[0].has_custom("context_speaker_similarity") + ): + return example.supervisions[0].context_speaker_similarity >= self.min_context_speaker_similarity + else: + return True + + class TokenCountFilter: """ Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise. 
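For reference, here is a small, self-contained sketch of how the three filters added above behave, wired up with the same thresholds the new Lhotse dataloader config exposes (`keep`, `max_cer`, `min_context_speaker_similarity`). The cut and supervision metadata values are fabricated purely for illustration; real cuts are expected to carry these custom fields from data preparation.

```python
# Illustrative usage of the new cut filters on a hand-built Lhotse cut.
from lhotse import SupervisionSegment
from lhotse.cut import MonoCut

from nemo.collections.common.data.lhotse.sampling import (
    CERFilter,
    ContextSpeakerSimilarityFilter,
    ValidationStatusFilter,
)

sup = SupervisionSegment(
    id="utt1-sup", recording_id="utt1", start=0.0, duration=2.5, text="hello",
    custom={"cer": 0.02, "context_speaker_similarity": 0.75},
)
cut = MonoCut(
    id="utt1", start=0.0, duration=2.5, channel=0, supervisions=[sup],
    custom={"validation_status": "pass"},
)

# Mirrors the config fields: keep, max_cer, min_context_speaker_similarity.
filters = [
    ValidationStatusFilter(keep="pass"),
    CERFilter(max_cer=0.03),
    ContextSpeakerSimilarityFilter(min_context_speaker_similarity=0.6),
]
print(all(f(cut) for f in filters))  # True: the cut passes all three filters

cut.custom["validation_status"] = "fail"
print(all(f(cut) for f in filters))  # False: rejected by ValidationStatusFilter
```

In the actual dataloader these filters are applied lazily to the whole CutSet via `cuts.filter(...)`, as shown in the `get_lhotse_sampler_from_config` changes above.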
diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py index 6479dc315ba2..fce56eeb4820 100644 --- a/nemo/collections/common/parts/utils.py +++ b/nemo/collections/common/parts/utils.py @@ -159,12 +159,16 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): class ClampActivation(nn.Module): - def __init__(self, min_value: float = -1.0, max_value: float = 1.0): + def __init__(self, min_value: float = -1.0, max_value: float = 1.0, clamp_training: bool = True): super().__init__() self.min_value = min_value self.max_value = max_value + self.clamp_training = clamp_training def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.training and not self.clamp_training: + return input + return torch.clamp(input, min=self.min_value, max=self.max_value) diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 988cb853a9e8..9cd20cfc83b5 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -37,23 +37,23 @@ vietnamese_text_preprocessing, ) from nemo.utils import logging -from nemo.utils.decorators import experimental class BaseTokenizer(ABC): + """Abstract class for creating an arbitrary tokenizer to convert string to list of int tokens. + Args: + tokens: List of tokens. + pad: Pad token as string. + blank: Blank token as string. + oov: OOV token as string. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + """ + PAD, BLANK, OOV = '', '', '' def __init__(self, tokens, *, pad=PAD, blank=BLANK, oov=OOV, sep='', add_blank_at=None): - """Abstract class for creating an arbitrary tokenizer to convert string to list of int tokens. - Args: - tokens: List of tokens. - pad: Pad token as string. - blank: Blank token as string. - oov: OOV token as string. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - """ super().__init__() tokens = list(tokens) @@ -95,6 +95,18 @@ def decode(self, tokens: List[int]) -> str: class BaseCharsTokenizer(BaseTokenizer): + """Base class for char-based tokenizer. + Args: + chars: string that represents all possible characters. + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + """ + # fmt: off # TODO @xueyang: unify definition of the default PUNCT_LIST and import from ipa_lexicon.py PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally @@ -114,17 +126,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=lambda x: x, ): - """Base class for char-based tokenizer. - Args: - chars: string that represents all possible characters. - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. 
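The `clamp_training` flag added to `ClampActivation` above makes clamping optional during training while keeping it at inference time. A small illustrative sketch, assuming the class is imported from the module shown in this diff:
```
# Sketch: ClampActivation with clamp_training=False is an identity in training
# mode and clamps only in eval mode.
import torch

from nemo.collections.common.parts.utils import ClampActivation

act = ClampActivation(min_value=-1.0, max_value=1.0, clamp_training=False)
x = torch.tensor([-2.0, 0.5, 3.0])

act.train()
print(act(x))  # [-2.0, 0.5, 3.0] -- values pass through unchanged during training

act.eval()
print(act(x))  # [-1.0, 0.5, 1.0] -- clamped at inference
```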
- add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - """ tokens = [] self.space, tokens = len(tokens), tokens + [' '] # Space @@ -175,6 +176,18 @@ def encode(self, text): class EnglishCharsTokenizer(BaseCharsTokenizer): + """English char-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones and apply lower() function. + """ + def __init__( self, punct=True, @@ -184,17 +197,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=english_text_preprocessing, ): - """English char-based tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones and apply lower() function. - """ super().__init__( chars=string.ascii_lowercase, punct=punct, @@ -207,6 +209,17 @@ def __init__( class VietnameseCharsTokenizer(BaseCharsTokenizer): + """Vietnamese grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it + would keep any word lowercase. + """ _LOCALE = "vi-VN" _CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="mixed") @@ -221,17 +234,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=vietnamese_text_preprocessing, ): - """Vietnamese grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. 
- text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it - would keep any word lowercase. - """ super().__init__( chars=chars, punct=punct, @@ -244,6 +246,17 @@ def __init__( class GermanCharsTokenizer(BaseCharsTokenizer): + """German grapheme-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it + would keep any word unchanged. + """ _LOCALE = "de-DE" _PUNCT_LIST = get_ipa_punctuation_list(_LOCALE) @@ -259,17 +272,6 @@ def __init__( non_default_punct_list=_PUNCT_LIST, text_preprocessing_func=any_locale_text_preprocessing, ): - """German grapheme-based tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it - would keep any word unchanged. - """ super().__init__( chars=chars, punct=punct, @@ -282,6 +284,15 @@ def __init__( class SpanishCharsTokenizer(BaseCharsTokenizer): + """Spanish grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + """ PUNCT_LIST = get_ipa_punctuation_list("es-ES") @@ -293,15 +304,6 @@ def __init__( pad_with_space=False, non_default_punct_list=None, ): - """Spanish grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - """ es_alphabet = "abcdefghijklmnopqrstuvwxyzáéíñóúü" super().__init__( @@ -316,6 +318,15 @@ def __init__( class FrenchCharsTokenizer(BaseCharsTokenizer): + """French grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. 
+ """ PUNCT_LIST = get_ipa_punctuation_list("fr-FR") @@ -327,15 +338,6 @@ def __init__( pad_with_space=False, non_default_punct_list=None, ): - """French grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - """ fr_alphabet = get_grapheme_character_set(locale="fr-FR", case="lower") super().__init__( @@ -350,20 +352,21 @@ def __init__( class ItalianCharsTokenizer(BaseCharsTokenizer): + """Italian grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + """ + PUNCT_LIST = get_ipa_punctuation_list("it-IT") def __init__( self, punct=True, apostrophe=True, add_blank_at=None, pad_with_space=False, non_default_punct_list=None ): - """Italian grapheme tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - """ it_alphabet = "abcdefghijklmnopqrstuvwxyzàèéìòùó" super().__init__( @@ -378,6 +381,18 @@ def __init__( class GermanPhonemesTokenizer(BaseCharsTokenizer): + """Deutsch phoneme-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Currently, it only applies lower() function. + """ + # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -395,17 +410,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=any_locale_text_preprocessing, ): - """Deutsch phoneme-based tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Currently, it only applies lower() function. 
- """ de_ipa = "abdefhijklmnoprstuvwxyzçðøŋœɐɑɒɔəɛɜɡɪɹɾʃʊʌʒː̃" de_suprasegmentals = "12" @@ -449,6 +453,18 @@ def encode(self, text): class ItalianPhonemesTokenizer(BaseCharsTokenizer): + """Italian phoneme-based tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Currently, it only applies lower() function. + """ + # fmt: off PUNCT_LIST = ( ',', '.', '!', '?', '-', @@ -467,17 +483,6 @@ def __init__( non_default_punct_list=None, text_preprocessing_func=italian_text_preprocessing, ): - """Italian phoneme-based tokenizer. - Args: - punct: Whether to reserve grapheme for basic punctuation or not. - apostrophe: Whether to use apostrophe or not. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Currently, it only applies lower() function. - """ it_ipa = ( "abcdefghijklmnopqrstuvwxyzàèéìòùóæɐɑɔəɚɜɬɹʌʔᵻðŋɛɡɣɪɲɾʃʊʎʒʝβθd͡'t͡'øɒɕɓçɖɘɝɞɟʄɡɠɢʛɦɧħɥʜɨɬɫɮʟɱɯɰɳɵɸœɶʘɺ" @@ -523,6 +528,28 @@ def encode(self, text): class EnglishPhonemesTokenizer(BaseTokenizer): + """English phoneme-based tokenizer. + Args: + g2p: Grapheme to phoneme module. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + stresses: Whether to use phonemes codes with stresses (0-2) or not. + chars: Whether to additionally use chars together with phonemes. It is useful if g2p module can return + chars too. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + oov: OOV token as string. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones. + Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be + handled by g2p). + """ + # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -559,27 +586,6 @@ def __init__( pad_with_space=False, text_preprocessing_func=lambda text: english_text_preprocessing(text, lower=False), ): - """English phoneme-based tokenizer. - Args: - g2p: Grapheme to phoneme module. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - stresses: Whether to use phonemes codes with stresses (0-2) or not. 
- chars: Whether to additionally use chars together with phonemes. It is useful if g2p module can return - chars too. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - oov: OOV token as string. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones. - Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be - handled by g2p). - """ self.phoneme_probability = None if hasattr(g2p, "phoneme_probability"): @@ -674,6 +680,7 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): @contextmanager def set_phone_prob(self, prob): + """Updates the phone probability inside context""" if hasattr(self.g2p, "phoneme_probability"): self.g2p.phoneme_probability = prob try: @@ -683,8 +690,31 @@ def set_phone_prob(self, prob): self.g2p.phoneme_probability = self.phoneme_probability -@experimental class IPATokenizer(BaseTokenizer): + """General-purpose IPA-based tokenizer. + Args: + g2p: Grapheme to phoneme module, should be IpaG2p or some subclass thereof. + locale: Locale used to determine default text processing logic and punctuation. + Supports ["en-US", "de-DE", "es-ES", "fr-FR"]. Defaults to "en-US". + Specify None if implementing custom logic for a new locale. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default, if any. + fixed_vocab: List of valid grapheme/phoneme tokens for the model. + Set only if overriding the default vocab generation process (reading from G2P dict). + If set, any dataset entries that have unincluded graphemes will be filtered out, and any words whose + pronunciations have unincluded phonemes will be treated as OOV. + Please make sure that the grapheme prefixes and cases are consistent with the G2P module's settings. + Defaults to None, which means default vocab generation is used. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + oov: OOV token as string. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + """ + def __init__( self, g2p, @@ -701,29 +731,6 @@ def __init__( add_blank_at=None, pad_with_space=False, ): - """General-purpose IPA-based tokenizer. - Args: - g2p: Grapheme to phoneme module, should be IpaG2p or some subclass thereof. - locale: Locale used to determine default text processing logic and punctuation. - Supports ["en-US", "de-DE", "es-ES", "fr-FR"]. Defaults to "en-US". - Specify None if implementing custom logic for a new locale. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default, if any. - fixed_vocab: List of valid grapheme/phoneme tokens for the model. 
- Set only if overriding the default vocab generation process (reading from G2P dict). - If set, any dataset entries that have unincluded graphemes will be filtered out, and any words whose - pronunciations have unincluded phonemes will be treated as OOV. - Please make sure that the grapheme prefixes and cases are consistent with the G2P module's settings. - Defaults to None, which means default vocab generation is used. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - oov: OOV token as string. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - """ if not hasattr(g2p, "symbols"): logging.error( f"Please make sure the G2P module passed into the IPATokenizer has a `symbols` attribute. " @@ -848,6 +855,7 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None) - @contextmanager def set_phone_prob(self, prob): + """Updates the phone probability inside context""" if hasattr(self.g2p, "phoneme_probability"): self.g2p.phoneme_probability = prob try: @@ -858,6 +866,26 @@ def set_phone_prob(self, prob): class ChinesePhonemesTokenizer(BaseTokenizer): + """Chinese phoneme-based tokenizer. + Note: This tokenizer for now covers Chinese phonemes/tones and English letters because our dataset contains + both Chinese and English graphemes. + Args: + g2p: Grapheme to phoneme module. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones. + Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be + handled by g2p). + """ + # fmt: off PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally ',', '.', '!', '?', '-', @@ -865,6 +893,7 @@ class ChinesePhonemesTokenizer(BaseTokenizer): ')', '[', ']', '{', '}', ) ZH_PUNCT_LIST = list(",。?!;:、‘’“”()【】「」《》") + list(PUNCT_LIST) + # fmt: on def __init__( self, @@ -880,25 +909,6 @@ def __init__( pad_with_space=False, text_preprocessing_func=chinese_text_preprocessing, ): - """Chinese phoneme-based tokenizer. - Note: This tokenizer for now covers Chinese phonemes/tones and English letters because our dataset contains - both Chinese and English graphemes. - Args: - g2p: Grapheme to phoneme module. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - sep: Separation token as string. 
- add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones. - Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be - handled by g2p). - """ tokens = [] self.space, tokens = len(tokens), tokens + [space] # Space @@ -952,8 +962,9 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): if p == space and len(ps) > 0 and ps[-1] != space: ps.append(p) # Add next phoneme or tone or ascii letter or apostrophe. - elif ((p.isalnum() or p == "'" or p in self.phoneme_list + self.tone_list + self.ascii_letter_list) and - p in tokens): + elif ( + p.isalnum() or p == "'" or p in self.phoneme_list + self.tone_list + self.ascii_letter_list + ) and p in tokens: ps.append(p) # Add punctuation elif (p in self.PUNCT_LIST) and self.punct: @@ -977,6 +988,24 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): class JapanesePhonemeTokenizer(BaseTokenizer): + """Japanese phoneme-based tokenizer. + Note: This tokenizer for now covers Japanese phonemes + Args: + g2p: Grapheme to phoneme module. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones. + Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be + handled by g2p). + """ JA_PUNCT_LIST = get_ipa_punctuation_list("ja-JP") @@ -994,24 +1023,6 @@ def __init__( pad_with_space=False, text_preprocessing_func=japanese_text_preprocessing, ): - """Japanese phoneme-based tokenizer. - Note: This tokenizer for now covers Japanese phonemes - Args: - g2p: Grapheme to phoneme module. - punct: Whether to reserve grapheme for basic punctuation or not. - non_default_punct_list: List of punctuation marks which will be used instead default. - space: Space token as string. - silence: Silence token as string (will be disabled if it is None). - apostrophe: Whether to use apostrophe or not. - sep: Separation token as string. - add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), - if None then no blank in labels. - pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. - text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. - Basically, it replaces all non-unicode characters with unicode ones. - Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be - handled by g2p). 
- """ tokens = [] self.space, tokens = len(tokens), tokens + [space] # Space @@ -1086,43 +1097,64 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): return [self._token2id[p] for p in ps] +# TODO @xueyang: subclassing from `nemo/collections/common/tokenizers/tokenizer_spec.py::TokenizerSpec`, and/or +# adjust to reuse `nemo/collections/common/tokenizers/aggregate_tokenizer.py::AggregateTokenizer` class AggregatedTTSTokenizer: + """A simple aggregated tokenizer. Aggregates multiple tokenizers into one by combining (simply concatenating) + their tokens into one vocabulary. + Args: + tokenizers: List of tokenizers to aggregate. + tokenizer_names: List of names for each tokenizer (usually the language identifier). + """ + def __init__(self, tokenizers: List[Union[BaseTokenizer, PreTrainedTokenizerBase]], tokenizer_names: List[str]): - """A simple aggregated tokenizer. Aggregates multiple tokenizers into one by combining (simply concatenating) - their tokens into one vocabulary. - Args: - tokenizers: List of tokenizers to aggregate. - tokenizer_names: List of names for each tokenizer (usually the language identifier). - """ assert len(tokenizers) == len(tokenizer_names), "Number of tokenizers and tokenizer names must be the same." tokens = [] - toknizer_offsets = {} + tokenizer_offsets = {} tokenizer_offset = 0 self.tokenizers = {} + num_tokens_per_tokenizer = {} + tokenizer_pad_ids = {} for idx, tokenizer in enumerate(tokenizers): - self.tokenizers[tokenizer_names[idx]] = tokenizer - toknizer_offsets[tokenizer_names[idx]] = tokenizer_offset + tokenizer_name = tokenizer_names[idx] + self.tokenizers[tokenizer_name] = tokenizer + tokenizer_offsets[tokenizer_name] = tokenizer_offset if isinstance(tokenizer, BaseTokenizer): tokens.extend(tokenizer.tokens) num_tokens = len(tokenizer.tokens) + tokenizer_pad_ids[tokenizer_name] = tokenizer.pad + tokenizer_offset elif isinstance(tokenizer, PreTrainedTokenizerBase): _tokens = list(tokenizer.get_vocab().keys()) tokens.extend(_tokens) num_tokens = len(_tokens) + tokenizer_pad_ids[tokenizer_name] = tokenizer.pad_token_id + tokenizer_offset else: raise ValueError("Tokenizers must be either BaseTokenizer or HuggingFace PreTrainedTokenizerBase.") tokenizer_offset += num_tokens + num_tokens_per_tokenizer[tokenizer_name] = num_tokens self.tokens = tokens self.tokenizer_names = tokenizer_names - self.toknizer_offsets = toknizer_offsets - self.pad = self.tokenizers[tokenizer_names[0]].pad # Use the first tokenizer's pad token + self.tokenizer_offsets = tokenizer_offsets + self.vocab_size = len(tokens) + self.num_tokens_per_tokenizer = num_tokens_per_tokenizer + self.tokenizer_pad_ids = tokenizer_pad_ids + # Define aggregated token's pad value from the first tokenizer's pad value + first_tokenizer = self.tokenizers[tokenizer_names[0]] + if hasattr(first_tokenizer, "pad_token_id"): # Defined in PreTrainedTokenizerBase subclasses + self.pad = first_tokenizer.pad_token_id + elif hasattr(first_tokenizer, "pad"): # Defined in BaseTokenizer subclasses + self.pad = first_tokenizer.pad + else: + raise ValueError("AggregatedTTSTokenizer could not find a padding token in the first tokenizer") - def encode(self, text: str, tokenizer_name: str) -> List[int]: + def encode(self, text: str, tokenizer_name: str = None) -> List[int]: + """Tokenizer encode from text to tokens""" tokenizer = self.tokenizers[tokenizer_name] tokens = tokenizer.encode(text) - return [self.toknizer_offsets[tokenizer_name] + token for token in tokens] + return 
[self.tokenizer_offsets[tokenizer_name] + token for token in tokens] - def decode(self, tokens: List[int], tokenizer_name: str) -> str: + def decode(self, tokens: List[int], tokenizer_name: str = None) -> str: + """Tokernizer decoder from tokens to text""" tokenizer = self.tokenizers[tokenizer_name] - return tokenizer.decode([token - self.toknizer_offsets[tokenizer_name] for token in tokens]) + return tokenizer.decode([token - self.tokenizer_offsets[tokenizer_name] for token in tokens]) diff --git a/nemo/collections/speechlm2/models/duplex_s2s_model.py b/nemo/collections/speechlm2/models/duplex_s2s_model.py index e0d90a6200d9..2de158d88be9 100644 --- a/nemo/collections/speechlm2/models/duplex_s2s_model.py +++ b/nemo/collections/speechlm2/models/duplex_s2s_model.py @@ -55,7 +55,7 @@ def __init__(self, cfg: dict) -> None: self.cfg = DictConfig(cfg) setup_audio_codec(self) - self._codebook_size = self.audio_codec.vector_quantizer.codebook_size_per_group + self._codebook_size = self.audio_codec.vector_quantizer.codebook_size self._num_codebooks = self.audio_codec.vector_quantizer.num_groups # We load the pretrained HF LLM using "ForCausalLM" variant so that we can obtain the diff --git a/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py b/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py index df246c6640b5..3605e886b3e4 100644 --- a/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py +++ b/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py @@ -56,7 +56,7 @@ def __init__(self, cfg: dict) -> None: self.cfg = DictConfig(cfg) setup_audio_codec(self) - self._codebook_size = self.audio_codec.vector_quantizer.codebook_size_per_group + self._codebook_size = self.audio_codec.vector_quantizer.codebook_size self._num_codebooks = self.audio_codec.vector_quantizer.num_groups # We load the pretrained HF LLM using "ForCausalLM" variant so that we can obtain the diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index 63aedb3eea46..aace2f198f26 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -37,14 +37,13 @@ ) from nemo.core.classes import Dataset from nemo.utils import logging -from nemo.utils.decorators import experimental @dataclass class DatasetMeta: manifest_path: Path audio_dir: Path - feature_dir: Path + feature_dir: Path = None sample_weight: float = 1.0 tokenizer_names: List[str] = None @@ -61,7 +60,6 @@ class DatasetSample: tokenizer_names: List[str] = None -@experimental class TextToSpeechDataset(Dataset): """ Class for processing and loading text to speech training examples. @@ -338,7 +336,7 @@ class MagpieTTSDataset(TextToSpeechDataset): max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration' will be ignored. volume_norm: Whether to apply volume normalization to loaded audio. - codec_model_downsample_factor: Downsample factor of the codec model (Num samples in waveform per codec frame). + codec_model_samples_per_frame: Num samples in waveform per codec frame (codec downsample factor). bos_id: Text BOS token id. eos_id: Text EOS token id. audio_bos_id: Audio BOS token id. @@ -355,6 +353,8 @@ class MagpieTTSDataset(TextToSpeechDataset): pad_context_text_to_max_duration: Whether to pad context text to max context audio frames. context_duration_min: Minimum duration of context audio in seconds. 
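The reworked `AggregatedTTSTokenizer` above keeps per-tokenizer offsets and pad ids (`tokenizer_offsets`, `tokenizer_pad_ids`, `vocab_size`) instead of only a single global pad token. A rough usage sketch; the two sub-tokenizers chosen here are just an example for illustration, not what the model configs necessarily use:
```
# Sketch: ids from each sub-tokenizer are shifted by the size of the vocabularies
# that precede it in the aggregated vocabulary, and decode() undoes the shift
# for the named sub-tokenizer.
from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import (
    AggregatedTTSTokenizer,
    EnglishCharsTokenizer,
    SpanishCharsTokenizer,
)

agg = AggregatedTTSTokenizer(
    tokenizers=[EnglishCharsTokenizer(), SpanishCharsTokenizer()],
    tokenizer_names=["en", "es"],
)

en_ids = agg.encode("hello", tokenizer_name="en")  # first tokenizer, offset 0
es_ids = agg.encode("hola", tokenizer_name="es")   # offset = size of the English vocab

print(agg.decode(es_ids, tokenizer_name="es"))     # expect "hola"
print(agg.vocab_size, agg.tokenizer_offsets)       # total size and {"en": 0, "es": <en vocab size>}
print(agg.tokenizer_pad_ids)                       # per-tokenizer pad ids, already offset
```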
context_duration_max: Maximum duration of context audio in seconds. + text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. + text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. """ def __init__( @@ -365,7 +365,7 @@ def __init__( min_duration: Optional[float] = None, max_duration: Optional[float] = None, volume_norm: bool = True, - codec_model_downsample_factor: int = None, + codec_model_samples_per_frame: int = None, bos_id: int = None, eos_id: int = None, audio_bos_id: int = None, @@ -379,9 +379,12 @@ def __init__( tokenizer_config=None, load_16khz_audio: bool = True, use_text_conditioning_tokenizer: bool = False, + text_conditioning_tokenizer_name: str = None, pad_context_text_to_max_duration: bool = False, context_duration_min: float = 3.0, context_duration_max: float = 10.0, + text_context_remapping: Dict[str, str] = None, + text_context_remapping_prob: float = 0.0, ): super().__init__( dataset_meta=dataset_meta, @@ -396,14 +399,14 @@ def __init__( max_duration=max_duration, volume_norm=volume_norm, ) - self.bos_id = bos_id + self.bos_id = bos_id # TODO @xueyang: this should be removed since no other places used it. self.eos_id = eos_id self.audio_bos_id = audio_bos_id self.audio_eos_id = audio_eos_id self.context_audio_bos_id = context_audio_bos_id self.context_audio_eos_id = context_audio_eos_id self.num_audio_codebooks = num_audio_codebooks - self.codec_model_downsample_factor = codec_model_downsample_factor + self.codec_model_samples_per_frame = codec_model_samples_per_frame self.include_align_prior = prior_scaling_factor is not None self.prior_scaling_factor = prior_scaling_factor self.load_cached_codes_if_available = load_cached_codes_if_available @@ -412,16 +415,16 @@ def __init__( self.text_tokenizer = None # Assigned in worker_init_fn in model file self.load_16khz_audio = load_16khz_audio self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer - self.text_conditioning_tokenizer = ( - None # Assigned in worker_init_fn in model file if use_text_conditioning_tokenizer is True - ) + self.text_conditioning_tokenizer_name = text_conditioning_tokenizer_name self.pad_context_text_to_max_duration = pad_context_text_to_max_duration self.context_duration_min = context_duration_min self.context_duration_max = context_duration_max + self.text_context_remapping = text_context_remapping + self.text_context_remapping_prob = text_context_remapping_prob def get_num_audio_samples_to_slice(self, duration, sample_rate): - num_codec_frames = int(duration * sample_rate / self.codec_model_downsample_factor) - num_audio_samples = num_codec_frames * self.codec_model_downsample_factor + num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) + num_audio_samples = num_codec_frames * self.codec_model_samples_per_frame return num_audio_samples def __getitem__(self, index): @@ -445,13 +448,16 @@ def __getitem__(self, index): audio_codes_path = data.manifest_entry['target_audio_codes_path'] audio_codes = torch.load(audio_codes_path).long() # (C, T) spec_len = audio_codes.shape[1] + 1 # +1 for EOS - auidio_bos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_bos_id, dtype=audio_codes.dtype) + audio_bos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_bos_id, dtype=audio_codes.dtype) audio_eos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_eos_id, dtype=audio_codes.dtype) - audio_codes = torch.cat([auidio_bos_tensor, audio_codes, 
audio_eos_tensor], dim=1) + audio_codes = torch.cat([audio_bos_tensor, audio_codes, audio_eos_tensor], dim=1) audio_codes_len = audio_codes.shape[1] example['audio_codes'] = audio_codes example['audio_codes_len'] = audio_codes_len example['audio_filepath'] = audio_codes_path + if 'audio_filepath' in data.manifest_entry: + # If audio_filepath is available, then use the actual audio file path. + example['audio_filepath'] = data.manifest_entry['audio_filepath'] else: # Only load audio if codes are not available audio_array, _, audio_filepath_rel = load_audio( @@ -464,14 +470,14 @@ def __getitem__(self, index): # Pad audio to be multiple of downsample factor audio = torch.nn.functional.pad( audio, - (0, self.codec_model_downsample_factor - (audio.shape[0] % self.codec_model_downsample_factor)), + (0, self.codec_model_samples_per_frame - (audio.shape[0] % self.codec_model_samples_per_frame)), value=0, ) audio_len = audio.shape[0] example['audio_filepath'] = data.manifest_entry['audio_filepath'] example['audio'] = audio example['audio_len'] = audio_len - spec_len = int(audio_len / self.codec_model_downsample_factor) + 1 # +1 for EOS + spec_len = int(audio_len / self.codec_model_samples_per_frame) + 1 # +1 for EOS if self.load_cached_codes_if_available and 'context_audio_codes_path' in data.manifest_entry: context_audio_codes_path = data.manifest_entry['context_audio_codes_path'] @@ -479,7 +485,7 @@ def __getitem__(self, index): # Sample random duration between self.context_duration_min and self.context_duration_max _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) _num_frames_to_slice = int( - _context_duration_to_slice * self.sample_rate / self.codec_model_downsample_factor + _context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame ) if _num_frames_to_slice < context_audio_codes.shape[1]: start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) @@ -542,7 +548,9 @@ def __getitem__(self, index): example['context_audio_codes_len'] = context_audio_codes_len else: # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes - context_audio = torch.zeros(self.codec_model_downsample_factor, dtype=torch.float32) + # @blisc: Added a +1. If we send in exactly 882 samples, then a conv layer complains about padding. + # Adding 883 works. This occurs when we use text context during inference. + context_audio = torch.zeros(self.codec_model_samples_per_frame + 1, dtype=torch.float32) context_audio_len = context_audio.shape[0] example['context_audio'] = context_audio example['context_audio_len'] = context_audio_len @@ -576,17 +584,22 @@ def __getitem__(self, index): if self.use_text_conditioning_tokenizer: if 'context_text' in data.manifest_entry: - context_tokens = self.text_conditioning_tokenizer(data.manifest_entry['context_text'])['input_ids'] + context_text = data.manifest_entry['context_text'] + if self.text_context_remapping is not None and context_text in self.text_context_remapping: + if self.dataset_type == 'train' and random.random() < self.text_context_remapping_prob: + # Only remap during training. Give the exact text context during inference. 
+ context_text = self.text_context_remapping[context_text] + context_tokens = self.text_tokenizer.encode(context_text, self.text_conditioning_tokenizer_name) example['has_text_context'] = True else: - context_tokens = self.text_conditioning_tokenizer("[NO TEXT CONTEXT]")['input_ids'] + context_tokens = self.text_tokenizer.encode("[NO TEXT CONTEXT]", self.text_conditioning_tokenizer_name) example['has_text_context'] = False if self.pad_context_text_to_max_duration: _required_len = ( - int(self.context_duration_max * self.sample_rate / self.codec_model_downsample_factor) + 2 + int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 ) # +2 for BOS and EOS if len(context_tokens) < _required_len: - _pad_id = self.text_conditioning_tokenizer.pad_token_id + _pad_id = self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] context_tokens += [_pad_id] * (_required_len - len(context_tokens)) else: context_tokens = context_tokens[:_required_len] @@ -603,7 +616,15 @@ def __getitem__(self, index): align_prior = torch.tensor(align_prior, dtype=torch.float32) example["align_prior"] = align_prior - example['raw_text'] = data.text + if "original_text" in data.manifest_entry: + # Raw Text is used as the GT for CER/WER computation in DPO pref data generation + # and GRPO reward setup. For manifests in which the 'text' field is phonemized, + # we use the 'original_text' field as the raw text. Otherwise, we use the regular text field. + example['raw_text'] = data.manifest_entry['original_text'] + else: + example['raw_text'] = data.text + + example['language'] = data.manifest_entry.get('language', 'en') if "reward" in data.manifest_entry: example["reward"] = data.manifest_entry["reward"] @@ -631,10 +652,12 @@ def collate_fn(self, batch: List[dict]): context_has_text_context_list = [] reward_list = [] raw_text_list = [] + language_list = [] for example in batch: dataset_name_list.append(example["dataset_name"]) audio_filepath_list.append(example["audio_filepath"]) raw_text_list.append(example["raw_text"]) + language_list.append(example["language"]) token_list.append(example["tokens"]) token_len_list.append(example["text_len"]) @@ -677,6 +700,7 @@ def collate_fn(self, batch: List[dict]): batch_dict = { "dataset_names": dataset_name_list, "raw_texts": raw_text_list, + "languages": language_list, "audio_filepaths": audio_filepath_list, "text": batch_tokens, "text_lens": batch_token_len, @@ -713,6 +737,7 @@ def collate_fn(self, batch: List[dict]): if len(context_audio_codes_list) > 0: batch_context_audio_codes_len = torch.IntTensor(context_audio_codes_len_list) context_audio_codes_max_len = int(batch_context_audio_codes_len.max().item()) + # TODO @xueyang: verify if batch_context_audio_codes are integer. batch_context_audio_codes = stack_tensors(context_audio_codes_list, max_lens=[context_audio_codes_max_len]) batch_dict['context_audio_codes'] = batch_context_audio_codes batch_dict['context_audio_codes_lens'] = batch_context_audio_codes_len @@ -720,6 +745,7 @@ def collate_fn(self, batch: List[dict]): if self.use_text_conditioning_tokenizer: batch_context_text_tokens_len = torch.IntTensor(context_text_tokens_len_list) context_text_tokens_max_len = int(batch_context_text_tokens_len.max().item()) + # TODO @xueyang: potential bugs if self.tokenizer.pad is not 0.0. verify if batch_context_text_tokens are integer. 
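The `text_context_remapping` / `text_context_remapping_prob` options added above collapse several equivalent text contexts into one canonical string, with the remapping applied stochastically and only during training. A standalone sketch of the same logic; the context strings below are made up for illustration:
```
# Sketch mirroring the remapping branch above: remap only in training mode, with
# probability text_context_remapping_prob; inference always keeps the original context.
import random

text_context_remapping = {
    "happy voice": "speak in a happy tone",   # hypothetical context strings
    "speak happily": "speak in a happy tone",
}
text_context_remapping_prob = 0.5
dataset_type = "train"

context_text = "happy voice"
if text_context_remapping is not None and context_text in text_context_remapping:
    if dataset_type == "train" and random.random() < text_context_remapping_prob:
        context_text = text_context_remapping[context_text]

print(context_text)  # roughly half of training samples see "speak in a happy tone"
```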
batch_context_text_tokens = stack_tensors(context_text_tokens_list, max_lens=[context_text_tokens_max_len]) batch_dict['context_text_tokens'] = batch_context_text_tokens batch_dict['context_text_tokens_lens'] = batch_context_text_tokens_len diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py new file mode 100644 index 000000000000..6090c4b164f0 --- /dev/null +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -0,0 +1,498 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import re +from typing import Dict, List, Union + +import numpy as np +import torch +from hydra.utils import instantiate +from lhotse import CutSet +from lhotse.dataset.collation import collate_matrices, collate_vectors +from omegaconf import DictConfig +from transformers import AutoTokenizer, T5Tokenizer + +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + beta_binomial_prior_distribution, + normalize_volume, + stack_tensors, +) +from nemo.utils import logging + + +def setup_tokenizers(all_tokenizers_config, mode='train'): + # Being used in both model and worker_init_fn, so it is defined here + # Returns two tokenizers: one for TTS transcript and one for conditioning text (if needed) + tokenizers = [] + tokenizer_names = [] + for tokenizer_name in all_tokenizers_config: + tokenizer_config = all_tokenizers_config[tokenizer_name] + if tokenizer_config._target_ == 'AutoTokenizer': + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.pretrained_model) + elif tokenizer_config._target_ == 'T5Tokenizer': + tokenizer = T5Tokenizer.from_pretrained(tokenizer_config.pretrained_model) + else: + text_tokenizer_kwargs = {} + if "g2p" in tokenizer_config: + text_tokenizer_kwargs["g2p"] = instantiate(tokenizer_config.g2p) + tokenizer = instantiate(tokenizer_config, **text_tokenizer_kwargs) + # TODO @xueyang: is it really necessary to set phone probability to 1.0 for test mode? + if mode == 'test' and hasattr(tokenizer, "set_phone_prob"): + tokenizer.set_phone_prob(1.0) + tokenizers.append(tokenizer) + tokenizer_names.append(tokenizer_name) + + aggregated_tokenizer = AggregatedTTSTokenizer(tokenizers, tokenizer_names) # TTS Transcript tokenizer + + return aggregated_tokenizer + + +def check_speaker_format(item: str): + # enforce the format as example like "| Language:en Dataset:HiFiTTS Speaker:9136_other |". + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + return bool(re.match(pattern, item)) + + +class MagpieTTSLhotseDataset(torch.utils.data.Dataset): + """ + A PyTorch Dataset for loading and processing Text-to-Speech data for + MagpieTTS models using Lhotse CutSets, specifically designed for datasets + with text or audio context. But either context can be optional. 
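`check_speaker_format` above enforces a packed speaker-string convention from which `__getitem__` later recovers the dataset name. A small self-contained sketch using the example format given in the code:
```
# Sketch: validate the packed speaker string and recover the dataset name the
# same way the new dataset's __getitem__ does (speaker.strip().split()[2]).
import re


def check_speaker_format(item: str) -> bool:
    # Same pattern as in the new dataset module.
    pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|"
    return bool(re.match(pattern, item))


speaker = "| Language:en Dataset:HiFiTTS Speaker:9136_other |"
assert check_speaker_format(speaker)

dataset_name = speaker.strip().split()[2].split(":")[-1]
print(dataset_name)  # "HiFiTTS"
```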
+ + This dataset expects Lhotse Cut objects where each cut represents a + target utterance along with its preceding context. Context can be + audio (preferred) or text. It handles loading either pre-computed audio + codes or raw audio waveforms, applying volume normalization, and tokenizing + text transcripts. Context audio/codes are sliced or repeated to fit within + a specified duration range. Optionally, it loads 16kHz audio suitable for + speaker verification models and calculates alignment priors. + + Tokenizers (for target text and optional context text) are initialized lazily + within each dataloader worker process upon first access. + + Args: + sample_rate (int): Target sample rate for loading audio. Audio will be + resampled if necessary. + volume_norm (bool): If True, applies peak volume normalization to audio + waveforms. Defaults to True. + codec_model_samples_per_frame (int): The total downsampling factor of the + audio codec model used to generate codes. Used for padding audio + and calculating number of codec frames. + audio_bos_id (int): Token ID representing the beginning-of-sequence (BOS) for + target audio codes. + audio_eos_id (int): Token ID representing the end-of-sequence (EOS) for target + audio codes. + context_audio_bos_id (int): Token ID representing the beginning-of-sequence (BOS) + for context audio codes. + context_audio_eos_id (int): Token ID representing the end-of-sequence (EOS)for + context audio codes. + num_audio_codebooks (int): Number of codebooks used by the audio codec model. + Needed for creating dummy context codes if necessary. + prior_scaling_factor (Optional[float]): Scaling factor for the beta-binomial + alignment prior calculation. If None, priors are not computed. Defaults to None. + load_cached_codes_if_available (bool): If True, attempts to load pre-computed + audio codes from custom fields in the Lhotse Cut (e.g., 'codes_21fpsCausalDecoder', + 'context_codes_21fpsCausalDecoder'). Falls back to loading audio if codes + are not found. Defaults to True. + dataset_type (str): Specifies the mode ('train' or 'test'), mainly affecting + tokenizer settings like phoneme probability. Defaults to 'train'. + load_16khz_audio (bool): If True, loads 16kHz audio suitable for speaker + verification models. It prioritizes context audio ('context_recording' field) + if available, otherwise uses the target audio ('recording' field). + Defaults to True. + pad_context_text_to_max_duration (bool): If True and `use_text_conditioning_tokenizer` + is True, pads the tokenized context text to a length derived from + `context_duration_max`. Defaults to False. + context_duration_min (float): Minimum duration (in seconds) for the context + audio/codes. Context shorter than this will be repeated. Defaults to 3.0. + context_duration_max (float): Maximum duration (in seconds) for the context + audio/codes. Context longer than this will be sliced randomly. Defaults to 10.0. + use_text_conditioning_tokenizer (bool): If True, enables processing of context + text using a separate tokenizer (currently T5Tokenizer). Expects context text + in `cut.supervisions[0].custom['context_text']`. Defaults to False. + tokenizer_config (Optional[DictConfig]): Configuration for the text tokenizers. + Used for lazy initialization within workers. Must be provided if tokenizers + are not set externally. Defaults to None. + text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. 
+ text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. + """ + + def __init__( + self, + sample_rate: int, + volume_norm: bool = True, + codec_model_samples_per_frame: int = None, + audio_bos_id: int = None, + audio_eos_id: int = None, + context_audio_bos_id: int = None, + context_audio_eos_id: int = None, + num_audio_codebooks: int = None, + prior_scaling_factor: float = None, + load_cached_codes_if_available: bool = True, + dataset_type: str = 'train', + load_16khz_audio: bool = True, + pad_context_text_to_max_duration: bool = False, + context_duration_min: float = 3.0, + context_duration_max: float = 10.0, + use_text_conditioning_tokenizer: bool = False, + text_conditioning_tokenizer_name: str = None, + tokenizer_config: DictConfig = None, + text_context_remapping: Dict[str, str] = None, + text_context_remapping_prob: float = 0.0, + ): + super().__init__() + self.sample_rate = sample_rate + self.volume_norm = volume_norm + self.audio_bos_id = audio_bos_id + self.audio_eos_id = audio_eos_id + self.context_audio_bos_id = context_audio_bos_id + self.context_audio_eos_id = context_audio_eos_id + + self.codec_model_samples_per_frame = codec_model_samples_per_frame + self.num_audio_codebooks = num_audio_codebooks + + self.include_align_prior = prior_scaling_factor is not None + self.prior_scaling_factor = prior_scaling_factor + self.load_cached_codes_if_available = load_cached_codes_if_available + self.dataset_type = dataset_type # 'train' or 'test' + self.load_16khz_audio = load_16khz_audio + self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer + self.text_conditioning_tokenizer_name = text_conditioning_tokenizer_name + self.pad_context_text_to_max_duration = pad_context_text_to_max_duration + self.context_duration_min = context_duration_min + self.context_duration_max = context_duration_max + self.tokenizer_config = tokenizer_config + self.text_tokenizer = None + self.text_context_remapping = text_context_remapping + self.text_context_remapping_prob = text_context_remapping_prob + + def get_num_audio_samples_to_slice(self, duration, sample_rate): + num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) + num_audio_samples = num_codec_frames * self.codec_model_samples_per_frame + return num_audio_samples + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: + # Lazily initialize tokenizers. The first time any specific worker + process calls this function, on its copy of the dataset, the + tokenizers are created for that worker. All subsequent calls + to this function will reuse the tokenizers. This is equivalent to + the `worker_init_fn` in MagpieTTSModel. + if self.text_tokenizer is None: + # First time this worker is accessing the dataset, initialize the + tokenizers. If called by the main process (num_workers=0), worker_info will be None.
+ worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + logging.info(f"Worker {worker_id} initializing tokenizers...") + self.text_tokenizer = setup_tokenizers( + all_tokenizers_config=self.tokenizer_config, + mode=self.dataset_type, + ) + self.bos_id = len(self.text_tokenizer.tokens) + self.eos_id = self.bos_id + 1 + self.pad_id = self.text_tokenizer.pad + + # define list to store batched information + dataset_name_list = [] + audio_list = [] + audio_len_list = [] + audio_list_16khz = [] + audio_len_list_16khz = [] + token_list = [] + token_len_list = [] + prior_list = [] + audio_codes_list = [] + audio_codes_len_list = [] + context_audio_list = [] + context_audio_len_list = [] + context_audio_codes_list = [] + context_audio_codes_len_list = [] + context_text_tokens_list = [] + context_text_tokens_len_list = [] + context_has_text_context_list = [] + reward_list = [] + raw_text_list = ( + [] + ) # raw text here is the string of normalized text or text stored in the supervision segment. Used to distinguish from text tokens. + for cut in cuts: + speaker = cut.supervisions[0].speaker + if not check_speaker_format(speaker): + raise ValueError(f"Invalid format in cut.supervisions[0].speaker: {speaker}") + dataset_name = speaker.strip().split()[2].split(":")[-1] + dataset_name_list.append(dataset_name) + + # target audio or target codes + if self.load_cached_codes_if_available and cut.has_custom("target_codes"): + # TODO @xueyang: applying Tensor.long(), i.e. torch.int64, is not necessary. + + # Note that we have segmented the audio according to offset and duration so that the audio codes should + # not specify start and duration again when calling TemporalArray.load(start, duration). Ensure start + # and duration are None to the load function. + audio_codes = torch.from_numpy(cut.target_codes.load()).long() # (C, T) + spec_len = audio_codes.shape[1] + 1 # +1 for EOS + audio_bos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_bos_id, dtype=audio_codes.dtype) + audio_eos_tensor = torch.full((audio_codes.shape[0], 1), self.audio_eos_id, dtype=audio_codes.dtype) + audio_codes = torch.cat([audio_bos_tensor, audio_codes, audio_eos_tensor], dim=1) + audio_codes_len = audio_codes.shape[1] + audio_codes_list.append(audio_codes.T) # transpose to (T, C) to use collate_matrices to process batch. + audio_codes_len_list.append(audio_codes_len) + else: + # Only load audio if codes are not available + audio_array = cut.recording.resample(self.sample_rate).load_audio().squeeze(0) + if self.volume_norm: + audio_array = normalize_volume(audio_array) + audio = torch.from_numpy(audio_array) + # Pad audio to be multiple of downsample factor + audio = torch.nn.functional.pad( + audio, + (0, self.codec_model_samples_per_frame - (audio.shape[0] % self.codec_model_samples_per_frame)), + value=0, + ) + audio_len = audio.shape[0] + spec_len = int(audio_len / self.codec_model_samples_per_frame) + 1 # +1 for EOS + audio_list.append(audio) + audio_len_list.append(audio_len) + + # context audio or context codes + if self.load_cached_codes_if_available and cut.has_custom("context_codes"): + # TODO @xueyang: applying Tensor.long(), i.e. torch.int64, is not necessary. + + # Note that we have segmented the audio according to offset and duration so that the audio codes should + # not specify start and duration again when calling TemporalArray.load(start, duration). Ensure start + # and duration are None to the load function. 
+ context_audio_codes = torch.from_numpy(cut.context_codes.load()).long() # (8, T) + # Sample random duration between self.context_duration_min and self.context_duration_max + _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) + _num_frames_to_slice = int( + _context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame + ) + if _num_frames_to_slice < context_audio_codes.shape[1]: + start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) + context_audio_codes = context_audio_codes[:, start_idx : start_idx + _num_frames_to_slice] + else: + # Repeat the audio if it is shorter than the desired duration + _num_repeats = int(np.ceil(_num_frames_to_slice / context_audio_codes.shape[1])) + # context_audio_codes is a tensor of shape (num_codebooks, T) + context_audio_codes_repeated = context_audio_codes.repeat(1, _num_repeats) + context_audio_codes = context_audio_codes_repeated[:, :_num_frames_to_slice] + + context_bos_tensor = torch.full( + (context_audio_codes.shape[0], 1), self.context_audio_bos_id, dtype=context_audio_codes.dtype + ) + context_eos_tensor = torch.full( + (context_audio_codes.shape[0], 1), self.context_audio_eos_id, dtype=context_audio_codes.dtype + ) + context_audio_codes = torch.cat([context_bos_tensor, context_audio_codes, context_eos_tensor], dim=1) + context_audio_codes_len = context_audio_codes.shape[1] + context_audio_codes_list.append( + context_audio_codes.T + ) # transpose to (T, 8) in order to use collate_matrices to process batch. + context_audio_codes_len_list.append(context_audio_codes_len) + elif cut.has_custom("context_recording"): + # Only load audio if codes are not available + context_audio_array = cut.context_recording.resample(self.sample_rate).load_audio().squeeze(0) + if self.volume_norm: + context_audio_array = normalize_volume(context_audio_array) + _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) + _num_samples_to_slice = self.get_num_audio_samples_to_slice( + _context_duration_to_slice, self.sample_rate + ) + if _num_samples_to_slice < len(context_audio_array): + start_idx = random.randint(0, len(context_audio_array) - _num_samples_to_slice) + context_audio_array = context_audio_array[start_idx : start_idx + _num_samples_to_slice] + else: + # Repeat the audio if it is shorter than the desired duration + _num_repeats = int(np.ceil(_num_samples_to_slice / len(context_audio_array))) + context_audio_array = np.tile(context_audio_array, _num_repeats) + context_audio_array = context_audio_array[:_num_samples_to_slice] + context_audio = torch.from_numpy(context_audio_array) + context_audio_len = context_audio.shape[0] + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio_len) + else: + # We always want to have context_audio_codes if available for multi-encoder model. These are ignored for single-encoder model. + # If context audio is not available, just use a dummy context_audio_codes + # (Will be used in text context scenario) + # TODO @xueyang: verified that this block should cover below 3 conditions which were handled well. + # 1. load_cached_codes_if_available and ["context_audio_codes_path", "context_audio_filepath"] not in data.manifest_entry; + # assign to example["context_audio_codes"] and example["context_audio_codes_len"] + # 2. 
load_cached_codes_if_available is not True and "context_audio_codes_path" in data.manifest_entry; + # assign to example["context_audio"] and example["context_audio_len"] + # 3. load_cached_codes_if_available is not True and ["context_audio_codes_path", "context_audio_filepath"] not in data.manifest_entry; + # assign to example["context_audio"] and example["context_audio_len"] + if self.load_cached_codes_if_available: + context_bos_tensor = torch.full( + (self.num_audio_codebooks, 1), self.context_audio_bos_id, dtype=torch.int32 + ) + context_eos_tensor = torch.full( + (self.num_audio_codebooks, 1), self.context_audio_eos_id, dtype=torch.int32 + ) + context_audio_codes = torch.cat([context_bos_tensor, context_eos_tensor], dim=1) + context_audio_codes_len = context_audio_codes.shape[1] + context_audio_codes_list.append( + context_audio_codes.T + ) # transpose to (T, C) to use collate_matrices to process batch. + context_audio_codes_len_list.append(context_audio_codes_len) + else: + # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes + context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32) + context_audio_len = context_audio.shape[0] + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio_len) + + if self.load_16khz_audio: + if cut.has_custom("context_recording"): + # use context audio for SV model + audio_array_16khz = cut.context_recording.resample(16_000).load_audio().squeeze(0) + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + else: + # Otherwise, load the target audio for SV model. + audio_array_16khz = cut.recording.resample(16_000).load_audio().squeeze(0) + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + _context_duration_to_slice = random.uniform(self.context_duration_min, self.context_duration_max) + _num_samples_to_slice = int(_context_duration_to_slice * 16_000) + if _num_samples_to_slice < len(audio_array_16khz): + start_idx = random.randint(0, len(audio_array_16khz) - _num_samples_to_slice) + audio_array_16khz = audio_array_16khz[start_idx : start_idx + _num_samples_to_slice] + audio_16khz = torch.from_numpy(audio_array_16khz) + audio_len_16khz = audio_16khz.shape[0] + audio_list_16khz.append(audio_16khz) + audio_len_list_16khz.append(audio_len_16khz) + + if self.use_text_conditioning_tokenizer: + if cut.supervisions[0].has_custom("context_text"): + context_text = cut.supervisions[0].context_text + if self.text_context_remapping is not None and context_text in self.text_context_remapping: + if self.dataset_type == 'train' and random.random() < self.text_context_remapping_prob: + # Only remap during training. Give the exact text context during inference. 
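+                            # For example, a remapping such as {'Emma_neutral': 'Emma', 'Emma_conversational': 'Emma'}
+                            # collapses style-specific text contexts into a single speaker with probability
+                            # text_context_remapping_prob (the remapping dict is loaded from the model config).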
+                            context_text = self.text_context_remapping[context_text]
+                context_text_tokens = self.text_tokenizer.encode(
+                    context_text, tokenizer_name=self.text_conditioning_tokenizer_name
+                )
+                has_text_context = True
+            else:
+                context_text_tokens = self.text_tokenizer.encode(
+                    "[NO TEXT CONTEXT]", tokenizer_name=self.text_conditioning_tokenizer_name
+                )
+                has_text_context = False
+            if self.pad_context_text_to_max_duration:
+                _required_len = (
+                    int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2
+                )  # +2 for BOS and EOS
+                if len(context_text_tokens) < _required_len:
+                    _pad_id = self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name]
+                    context_text_tokens += [_pad_id] * (_required_len - len(context_text_tokens))
+                else:
+                    # TODO @xueyang: It seems counter-intuitive to trim the text context tokens to the required
+                    # context length. For example, the context_tokens after trimming may correspond to the partial
+                    # context_text like "Speaker and Emotion: | Language:en Dataset" where the following string is trimmed: ":Riva Speaker:Rodney_DROP |".
+                    context_text_tokens = context_text_tokens[:_required_len]
+            context_text_tokens = torch.tensor(context_text_tokens, dtype=torch.int32)
+            context_text_tokens_len = context_text_tokens.shape[0]
+            context_text_tokens_list.append(context_text_tokens)
+            context_text_tokens_len_list.append(context_text_tokens_len)
+            context_has_text_context_list.append(has_text_context)
+
+            # tokenize transcript
+            # There may be a "normalized_text" field in the supervision segment. Prioritize it over "text" if available.
+            if cut.supervisions[0].has_custom("normalized_text"):
+                text_str = cut.supervisions[0].normalized_text
+            else:
+                text_str = cut.supervisions[0].text
+            raw_text_list.append(text_str)
+            if cut.has_custom("tokenizer_names"):
+                # Pick a random tokenizer from the list of tokenizers
+                tokenizer_name = random.choice(cut.tokenizer_names)
+            else:
+                tokenizer_name = "english_phoneme"  # Default to english phoneme tokenizer
+            tokens = self.text_tokenizer.encode(text=text_str, tokenizer_name=tokenizer_name)
+            tokens = tokens + [self.eos_id]  # Not adding BOS id
+            tokens = torch.tensor(tokens, dtype=torch.int32)
+            text_len = tokens.shape[0]
+            token_list.append(tokens)
+            token_len_list.append(text_len)
+
+            if self.include_align_prior:
+                align_prior = beta_binomial_prior_distribution(
+                    phoneme_count=text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor
+                )
+                align_prior = torch.tensor(align_prior, dtype=torch.float32)
+                prior_list.append(align_prior)
+
+            if cut.supervisions[0].has_custom("reward"):
+                reward = cut.supervisions[0].reward
+                reward_list.append(reward)
+
+        # collate vectors and matrices here.
+        batch_dict = {
+            "dataset_names": dataset_name_list,
+            "raw_texts": raw_text_list,
+            "text": collate_vectors(token_list, padding_value=self.pad_id),  # (B, max_len)
+            "text_lens": torch.IntTensor(token_len_list),
+        }
+
+        # audio for SV.
+        if len(audio_list_16khz) > 0:
+            batch_dict["audio_16khz"] = collate_vectors(audio_list_16khz, padding_value=0.0)
+            batch_dict["audio_lens_16khz"] = torch.IntTensor(audio_len_list_16khz)
+
+        # target audio and codes
+        if len(audio_list) > 0:
+            batch_dict["audio"] = collate_vectors(audio_list, padding_value=0.0)
+            batch_dict["audio_lens"] = torch.IntTensor(audio_len_list)
+        if len(audio_codes_list) > 0:
+            # transpose back to (B, 8, T) from (B, T, 8).
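+            # Illustrative shape flow (assuming 8 codebooks): each list element is (T_i, 8); collate_matrices pads them
+            # to (B, T_max, 8), and the transpose below restores the (B, 8, T_max) layout expected by the model.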
+            batch_dict["audio_codes"] = collate_matrices(audio_codes_list, padding_value=0).transpose(1, 2)
+            batch_dict["audio_codes_lens"] = torch.IntTensor(audio_codes_len_list)
+
+        # context audio and codes
+        if len(context_audio_list) > 0:
+            batch_dict["context_audio"] = collate_vectors(context_audio_list, padding_value=0.0)
+            batch_dict["context_audio_lens"] = torch.IntTensor(context_audio_len_list)
+        if len(context_audio_codes_list) > 0:
+            # transpose back to (B, 8, T) from (B, T, 8).
+            batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose(
+                1, 2
+            )
+            batch_dict["context_audio_codes_lens"] = torch.IntTensor(context_audio_codes_len_list)
+
+        if self.use_text_conditioning_tokenizer:
+            batch_dict['context_text_tokens'] = collate_vectors(
+                tensors=context_text_tokens_list,
+                padding_value=self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name],
+            )
+            batch_dict['context_text_tokens_lens'] = torch.IntTensor(context_text_tokens_len_list)
+            batch_dict['has_text_context'] = torch.BoolTensor(context_has_text_context_list)
+
+        if self.include_align_prior:
+            spec_max_len = max([prior.shape[0] for prior in prior_list])
+            text_max_len = max([prior.shape[1] for prior in prior_list])
+            batch_dict["align_prior_matrix"] = stack_tensors(prior_list, max_lens=[text_max_len, spec_max_len])
+
+        if len(reward_list) > 0:
+            batch_dict['rewards'] = torch.FloatTensor(reward_list)
+
+        # Assert exactly ONE of audio or audio_codes is in the batch
+        assert ('audio' in batch_dict) ^ ('audio_codes' in batch_dict)
+
+        # Assert the batch does not mix context_audio and context_audio_codes
+        if 'context_audio' in batch_dict:
+            assert 'context_audio_codes' not in batch_dict
+        if 'context_audio_codes' in batch_dict:
+            assert 'context_audio' not in batch_dict
+
+        return batch_dict
diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py
index 55fa9d678edd..84be82f47959 100644
--- a/nemo/collections/tts/data/vocoder_dataset.py
+++ b/nemo/collections/tts/data/vocoder_dataset.py
@@ -36,7 +36,6 @@
 from nemo.core.classes import Dataset, IterableDataset
 from nemo.utils import logging
 from nemo.utils import webdataset as wds
-from nemo.utils.decorators import experimental
 from nemo.utils.distributed import webdataset_split_by_workers
 
 VALID_FILE_FORMATS = ';'.join(['wav', 'mp3', 'flac', 'opus'] + [fmt.lower() for fmt in valid_sf_formats.keys()])
@@ -111,7 +110,6 @@ def preprocess_manifest(
     return samples, sample_weights
 
 
-@experimental
 class VocoderDataset(Dataset):
     """
     Class for processing and loading Vocoder training examples.
diff --git a/nemo/collections/tts/g2p/models/i18n_ipa.py b/nemo/collections/tts/g2p/models/i18n_ipa.py index ed0569eac98d..6a2927db4be4 100644 --- a/nemo/collections/tts/g2p/models/i18n_ipa.py +++ b/nemo/collections/tts/g2p/models/i18n_ipa.py @@ -28,10 +28,8 @@ from nemo.collections.tts.g2p.models.base import BaseG2p from nemo.collections.tts.g2p.utils import GRAPHEME_CASE_MIXED, GRAPHEME_CASE_UPPER, set_grapheme_case from nemo.utils import logging -from nemo.utils.decorators import experimental -@experimental class IpaG2p(BaseG2p): # fmt: off STRESS_SYMBOLS = ["ˈ", "ˌ"] diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index 6db3e30595c6..b970261e30c5 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -230,8 +230,8 @@ def output_types(self): @typecheck() def forward(self, audio_real, audio_gen, audio_len): spec_len = (audio_len // self.hop_length) + 1 - spec_real = self._compute_spectrogram(audio=audio_real, spec_len=spec_len) - spec_gen = self._compute_spectrogram(audio=audio_gen, spec_len=spec_len) + spec_real = self._compute_spectrogram(audio=audio_real.float(), spec_len=spec_len).to(audio_gen.dtype) + spec_gen = self._compute_spectrogram(audio=audio_gen.float(), spec_len=spec_len).to(audio_gen.dtype) loss = self.loss_fn(predicted=spec_gen, target=spec_real, target_len=spec_len) return loss @@ -512,3 +512,176 @@ def forward(self, disc_scores_real, disc_scores_gen): loss /= len(disc_scores_real) return loss + + +class MMDLoss(Loss): + """ + Maximum mean discrepancy (MMD) loss, as defined in https://arxiv.org/abs/2406.02315 + + Args: + kernel_radii: List of radii for Gaussian kernels + loss_scale: Constant to multiply loss by + """ + + def __init__(self, kernel_radii=(0.1, 1, 5, 10, 20, 50), loss_scale=1.0): + super().__init__() + self.kernel_radii = kernel_radii + self.loss_scale = loss_scale + + @staticmethod + def _exp_kernel(dxx, r): + return torch.exp((-0.5 / r) * dxx).sum() + + @staticmethod + def _shuffle_codebooks(x): + B, C, _ = x.size() + x_shuffled = torch.zeros_like(x) + for c in range(C): + batch_perm = torch.randperm(B, device=x.device) + x_shuffled[:, c, :] = x[batch_perm, c, :] + return x_shuffled + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'C', 'D'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + B, C, D = inputs.size() + + x = inputs + x_mean = x.mean(dim=(0,), keepdim=True) + x_stdev = torch.sqrt(x.var(dim=(0,), keepdim=True) + 1e-8) + x = (x - x_mean) / x_stdev + y = self._shuffle_codebooks(x) + + # [B, C * D] + x = x.reshape([B, C * D]) + y = y.reshape([B, C * D]) + + # [B, B] + xx = torch.mm(x, x.t()) + yy = torch.mm(y, y.t()) + zz = torch.mm(x, y.t()) + + rx = xx.diag().unsqueeze(0).expand_as(xx) + ry = yy.diag().unsqueeze(0).expand_as(yy) + + dxx = rx.t() + rx - 2.0 * xx + dyy = ry.t() + ry - 2.0 * yy + dxy = rx.t() + ry - 2.0 * zz + + loss = 0.0 + coeff = -2.0 / B**2 + denom = B * (B - 1) + for r in self.kernel_radii: + loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dxx, r) - B) / denom + loss += coeff * torch.utils.checkpoint.checkpoint(self._exp_kernel, dxy, r) + loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dyy, r) - B) / denom + + loss = loss.clamp(min=0) + loss = self.loss_scale * loss + return loss + + +class MMDCodebookLoss(Loss): + """ + MMD loss which 
incentivizes independence between codebooks within each timestep. + + Args: + num_codebooks: Number of codebooks. + codebook_dim: Dimension of a single codebook code. + loss_fn: MMDLoss instance. + """ + + def __init__(self, num_codebooks, codebook_dim, loss_fn): + super().__init__() + self.num_codebooks = num_codebooks + self.codebook_dim = codebook_dim + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + B, D, T = inputs.size() + + # [B, C, D / C, T] + x = inputs.reshape(B, self.num_codebooks, self.codebook_dim, T) + # [B*T, C, D / C] + x = rearrange(x, 'B C D T -> (B T) C D') + loss = self.loss_fn(inputs=x) + return loss + + +class MMDEmbeddingLoss(Loss): + """ + MMD loss which incentivizes independence between embedding values within each timestep. + + Args: + loss_fn: MMDLoss instance. + """ + + def __init__(self, loss_fn): + super().__init__() + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + # [B*T, 1, D] + x = rearrange(inputs, 'B D T -> (B T) D 1') + loss = self.loss_fn(inputs=x) + return loss + + +class MMDTimeLoss(Loss): + """ + MMD loss which incentivizes independence between different timesteps. + + Args: + loss_fn: MMDLoss instance. + """ + + def __init__(self, loss_fn): + super().__init__() + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, inputs): + x = rearrange(inputs, 'B D T -> B T D') + loss = self.loss_fn(inputs=x) + return loss diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index 2dba794253a6..37b6a9a50aaf 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -17,7 +17,12 @@ from nemo.collections.tts.models.fastpitch import FastPitchModel from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL from nemo.collections.tts.models.hifigan import HifiGanModel -from nemo.collections.tts.models.magpietts import MagpieTTS_Model, MagpieTTS_ModelDPO, MagpieTTS_ModelInference +from nemo.collections.tts.models.magpietts import MagpieTTSModel +from nemo.collections.tts.models.magpietts_preference_optimization import ( + MagpieTTSModelOfflinePO, + MagpieTTSModelOfflinePODataGen, + MagpieTTSModelOnlinePO, +) from nemo.collections.tts.models.mixer_tts import MixerTTSModel from nemo.collections.tts.models.radtts import RadTTSModel from nemo.collections.tts.models.spectrogram_enhancer import SpectrogramEnhancerModel @@ -39,9 +44,10 @@ "MelPsuedoInverseModel", "MixerTTSModel", "RadTTSModel", - "MagpieTTS_Model", - "MagpieTTS_ModelInference", - "MagpieTTS_ModelDPO", + "MagpieTTSModel", + "MagpieTTSModelOfflinePODataGen", + "MagpieTTSModelOfflinePO", + "MagpieTTSModelOnlinePO", "Tacotron2Model", "TwoStagesModel", "UnivNetModel", diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 33b9a80125b7..4f4ca21edb30 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ 
b/nemo/collections/tts/models/audio_codec.py @@ -32,7 +32,7 @@ SISDRLoss, TimeDomainLoss, ) -from nemo.collections.tts.modules.audio_codec_modules import ResNetSpeakerEncoder +from nemo.collections.tts.modules.audio_codec_modules import ResNetSpeakerEncoder, default_precision from nemo.collections.tts.modules.common import GaussianDropout from nemo.collections.tts.parts.utils.callbacks import LoggingCallback from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers @@ -42,7 +42,6 @@ from nemo.core.neural_types.neural_type import NeuralType from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler from nemo.utils import logging, model_utils -from nemo.utils.decorators import experimental try: import torchaudio @@ -52,7 +51,6 @@ HAVE_TORCHAUDIO = False -@experimental class AudioCodecModel(ModelPT): def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Convert to Hydra 1.0 compatible DictConfig @@ -143,6 +141,22 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.gen_loss_fn = instantiate(cfg.generator_loss) self.disc_loss_fn = instantiate(cfg.discriminator_loss) + self.mmd_loss_start_epoch = cfg.get("mmd_loss_start_epoch", 0) + + if "mmd_loss" in cfg: + self.mmd_loss_fn = instantiate(cfg.mmd_loss) + self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0) + else: + self.mmd_loss_fn = None + self.mmd_loss_scale = None + + if "mmd_time_loss" in cfg: + self.mmd_time_loss_fn = instantiate(cfg.mmd_time_loss) + self.mmd_time_loss_scale = cfg.get("mmd_time_loss_scale", 1.0) + else: + self.mmd_time_loss_fn = None + self.mmd_time_loss_scale = None + feature_loss_type = cfg.get("feature_loss_type", "relative") if feature_loss_type == "relative": self.feature_loss_fn = RelativeFeatureMatchingLoss() @@ -191,6 +205,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.lr_schedule_interval = None self.automatic_optimization = False + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def num_codebooks(self): + if self.vector_quantizer is None: + raise ValueError("This AudioCodecModel does not have a vector quantizer.") + + return self.vector_quantizer.num_codebooks + + @property + def codebook_size(self): + if self.vector_quantizer is None: + raise ValueError("This AudioCodecModel does not have a vector quantizer.") + + return self.vector_quantizer.codebook_size + def state_dict(self, destination=None, prefix='', keep_vars=False): if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} @@ -307,7 +339,9 @@ def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Te raise ValueError("Cannot quantize without quantizer") # vector quantizer is returning [C, B, T], where C is the number of codebooks - tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) + with default_precision(torch.float32): + # vector quantizer is returning [C, B, T], where C is the number of codebooks + tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) # use batch first for the output tokens = rearrange(tokens, 'C B T -> B C T') return tokens @@ -336,7 +370,9 @@ def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Te # vector quantizer is using [C, B, T], where C is the number of codebooks tokens = rearrange(tokens, 'B C T -> C B T') - dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) + with default_precision(torch.float32): + dequantized = 
self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) + dequantized = dequantized.to(self.dtype) # make sure dequantized is in the right dtype return dequantized @typecheck( @@ -389,6 +425,7 @@ def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch. """ # Convert a discrete representation to a dequantized vector for each frame dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len) + dequantized = dequantized.to(self.dtype) # make sure that the dequantized is in the model dtype # Apply decoder to obtain time-domain audio for each frame audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len) @@ -459,18 +496,22 @@ def _process_batch(self, batch): encoded = self.encoder_noise(encoded) if self.vector_quantizer: - if self.vector_quantizer_has_commit_loss: - encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) - else: - encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len) - commit_loss = 0.0 + with default_precision(torch.float32): + if self.vector_quantizer_has_commit_loss: + encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + else: + encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + commit_loss = 0.0 + + encoded = encoded.to(encoded.dtype) # make sure encoded is converted to the right dtype else: commit_loss = 0.0 # [B, T] + encoded = encoded.to(self.dtype) # make sure vector quantizer output is in the model dtype audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len) - return audio, audio_len, audio_gen, commit_loss + return audio, audio_len, audio_gen, commit_loss, encoded @property def disc_update_prob(self) -> float: @@ -487,7 +528,7 @@ def should_update_disc(self, batch_idx) -> bool: def training_step(self, batch, batch_idx): optim_gen, optim_disc = self.optimizers() - audio, audio_len, audio_gen, commit_loss = self._process_batch(batch) + audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch) metrics = { "global_step": self.global_step, @@ -508,7 +549,11 @@ def training_step(self, batch, batch_idx): generator_losses = [] - loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + # stft does not support bf16, so make it run in fp32 + loss_mel_l1, loss_mel_l2 = self.mel_loss_fn( + audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len + ) + if self.mel_loss_l1_scale: metrics["g_loss_mel_l1"] = loss_mel_l1 generator_losses.append(self.mel_loss_l1_scale * loss_mel_l1) @@ -517,7 +562,7 @@ def training_step(self, batch, batch_idx): generator_losses.append(self.mel_loss_l2_scale * loss_mel_l2) if self.stft_loss_scale: - loss_stft = self.stft_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) metrics["g_loss_stft"] = loss_stft generator_losses.append(self.stft_loss_scale * loss_stft) @@ -547,6 +592,19 @@ def training_step(self, batch, batch_idx): metrics["g_loss_commit"] = commit_loss generator_losses.append(self.commit_loss_scale * commit_loss) + if self.mmd_loss_scale: + loss_mmd = self.mmd_loss_fn(inputs=codes) + metrics["g_loss_mmd"] = loss_mmd + + if self.current_epoch >= self.mmd_loss_start_epoch: + generator_losses.append(self.mmd_loss_scale * loss_mmd) + + if self.mmd_time_loss_scale: + loss_mmd_time = self.mmd_time_loss_fn(inputs=codes) + metrics["g_loss_mmd_time"] = 
loss_mmd_time + if self.current_epoch >= self.mmd_loss_start_epoch: + generator_losses.append(self.mmd_time_loss_scale * loss_mmd_time) + # compute embeddings for speaker consistency loss if self.use_scl_loss: # concate generated and GT waveforms @@ -592,10 +650,12 @@ def on_train_epoch_end(self): self.update_lr("epoch") def validation_step(self, batch, batch_idx): - audio, audio_len, audio_gen, _ = self._process_batch(batch) + audio, audio_len, audio_gen, _, _ = self._process_batch(batch) - loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) - loss_stft = self.stft_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + loss_mel_l1, loss_mel_l2 = self.mel_loss_fn( + audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len + ) + loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len) loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 9f0ae9c45d5f..cc6c84a234af 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -11,91 +11,65 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy import json import os -import string -from typing import List +import random +import time +from functools import partial +from typing import Dict, List, Optional, Union -import librosa import numpy as np import soundfile as sf import torch +import wandb from hydra.utils import instantiate -from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from omegaconf import DictConfig, OmegaConf, open_dict from torch import nn from torch.utils.data import get_worker_info -from transformers import AutoTokenizer, T5Tokenizer -import nemo.collections.asr as nemo_asr -from nemo.collections.asr.metrics.wer import word_error_rate -from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 -from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths, plot_alignment_to_numpy -from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors +from nemo.collections.tts.modules.aligner import AlignmentEncoder +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import ( + CharAwareSubwordEncoder, + EOSDetectionMethod, + LocalTransformerType, + SpecialAudioToken, + cosine_schedule, +) +from nemo.collections.tts.parts.utils.helpers import ( + binarize_attention_parallel, + get_mask_from_lengths, + plot_alignment_to_numpy, +) from nemo.core.classes import 
ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging -def setup_tokenizers(all_tokenizers_config, use_text_conditioning_tokenizer, mode='train'): - # Being used in both model and worker_init_fn, so it is defined here - # Returns two tokenizers: one for TTS transcript and one for conditioning text (if needed) - tokenizers = [] - tokenizer_names = [] - for tokenizer_name in all_tokenizers_config: - tokenizer_config = all_tokenizers_config[tokenizer_name] - if tokenizer_config._target_ == 'AutoTokenizer': - tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.pretrained_model) - else: - text_tokenizer_kwargs = {} - if "g2p" in tokenizer_config: - text_tokenizer_kwargs["g2p"] = instantiate(tokenizer_config.g2p) - tokenizer = instantiate(tokenizer_config, **text_tokenizer_kwargs) - if mode == 'test' and hasattr(tokenizer, "set_phone_prob"): - tokenizer.set_phone_prob(1.0) - tokenizers.append(tokenizer) - tokenizer_names.append(tokenizer_name) - - aggregated_tokenizer = AggregatedTTSTokenizer(tokenizers, tokenizer_names) # TTS Transcript tokenizer - text_conditioning_tokenizer = None - - if use_text_conditioning_tokenizer: - # TODO: make this configurable - # Conditioning text tokenizer - text_conditioning_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") - - return aggregated_tokenizer, text_conditioning_tokenizer - - def worker_init_fn(worker_id): # For mp.set_start_method("spawn", force=True) # The dataset class should be picklable, so we initialize non-picklable objects here logging.info(f"Worker {worker_id} initializing...") worker_info = get_worker_info() dataset = worker_info.dataset # Get the dataset instance in this worker - tokenizer, text_conditioning_tokenizer = setup_tokenizers( - dataset.tokenizer_config, dataset.use_text_conditioning_tokenizer, mode=dataset.dataset_type - ) + tokenizer = setup_tokenizers(dataset.tokenizer_config, mode=dataset.dataset_type) dataset.text_tokenizer = tokenizer - dataset.text_conditioning_tokenizer = text_conditioning_tokenizer -class MagpieTTS_Model(ModelPT): +class MagpieTTSModel(ModelPT): """ Magpie-TTS Model Base Class used for training a TTS model that can generate audio codes from transcript and a context audio/text Supports multiple model types: - - single_encoder_sv_tts: Transcript goes into the encoder and target audio goes to the decoder. Additionally, - speaker_embedding of target audio (or context audio if provided) from TitaNet gets added to encoder - output(all timesteps). - - multi_encoder_context_tts: Transcript and context audio go to different encoders. Transcript encoding feeds to layers given by cfg.model.transcript_decoder_layers and the context encoding feeds into the layers given by context_decoder_layers .Also supports text context which gets encoded by the same encoder as context audio. @@ -106,8 +80,8 @@ class MagpieTTS_Model(ModelPT): value (5 seconds). Text context, which is usually shorter than number of codec frames of 5 second of audio, is padded to the max context duration in this model. - - decoder_pretrain_synthesizer: This is the model type used for pretraining the decoder only on audio data using - next frame prediction loss. + - decoder_ce: Same as decoder_context_tts except there is a small neural network between the context tensors and + the decoder input. 
""" def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): @@ -115,6 +89,67 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_devices + # load codec, disable loading of loss modules not needed during inference + codec_model_cfg = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), return_config=True) + if "use_scl_loss" in codec_model_cfg: + codec_model_cfg.use_scl_loss = False + codec_model = AudioCodecModel.restore_from( + cfg.get('codecmodel_path'), strict=False, override_config_path=codec_model_cfg + ) + self.sample_rate = codec_model.sample_rate + self.codec_model_samples_per_frame = codec_model.samples_per_frame + # del codec discriminator to free memory + del codec_model.discriminator + + # When using FSQ tokens, the codebook structure can be changed at any time. + # An FSQ definition can be provided in `vector_quantizer` config to train with a codebook structure + # that is different than in the audio codec checkpoint. + vector_quantizer = cfg.get('vector_quantizer') + if vector_quantizer is not None: + vector_quantizer = instantiate(vector_quantizer) + num_audio_codebooks = vector_quantizer.num_codebooks + codebook_size = vector_quantizer.codebook_size + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vector_quantizer, + ) + data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks + else: + num_audio_codebooks = codec_model.num_codebooks + data_num_audio_codebooks = num_audio_codebooks + codebook_size = codec_model.codebook_size + codec_converter = None + # The dataloader needs to know the number of codebooks that the context codes were stored in + # In the case where there are no context codes saved, and there is no context audio (in the text context path), + # We create a dummy context code tensor that is only [context_BOS, context_EOS] that is repeated for + # data_num_audio_codebooks + self.data_num_audio_codebooks = data_num_audio_codebooks + self.num_audio_codebooks = num_audio_codebooks + self.codebook_size = codebook_size + + # Our codebooks start with actual audio codec tokens, followed by special tokens. + # The `forced_*` options are for backward compatibility for models trained with older code. + get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size) + self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS)) + self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS)) + self.context_audio_bos_id = cfg.get( + 'forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS) + ) + self.context_audio_eos_id = cfg.get( + 'forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS) + ) + self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN)) + self.num_all_tokens_per_codebook = cfg.get( + 'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken) + ) + self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) + + # The frame stacking factor controls how many consecutive frames are processed together by the base decoder + # (and then refined into individual frames by the local transformer). A frame stacking factor of 1 means no + # frame stacking. We have a separate embedding table for each of the stacked frames, e.g. 
for frame stacking + # factor of 3, the entries of codebook 0 appear 3 times in the embedding table. + self.frame_stacking_factor = cfg.get('frame_stacking_factor', 1) + assert 'downsample_factor' not in cfg, '`downsample_factor` is deprecated, use `frame_stacking_factor` instead' # Setup tokenizer if hasattr(cfg, 'text_tokenizer'): # For backward compatibility for English-only models @@ -123,72 +158,152 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): del cfg['text_tokenizer'] self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False) - tokenizer, text_conditioning_tokenizer = self._setup_tokenizers(cfg) - self.tokenizer = tokenizer - self.text_conditioning_tokenizer = text_conditioning_tokenizer + # Using google-t5/t5-small as default text conditioning tokenizer for backward compatibility. + self.text_conditioning_tokenizer_name = cfg.get('text_conditioning_tokenizer_name', None) + self.legacy_text_conditioning = cfg.get('legacy_text_conditioning', False) + + if self.legacy_text_conditioning: + if self.text_conditioning_tokenizer_name is None: + self.text_conditioning_tokenizer_name = "google-t5/t5-small" + + tokenizer_target = "AutoTokenizer" + if self.text_conditioning_tokenizer_name == "google-t5/t5-small": + tokenizer_target = "T5Tokenizer" + + with open_dict(cfg): + cfg.text_tokenizers[self.text_conditioning_tokenizer_name] = { + '_target_': tokenizer_target, + 'pretrained_model': self.text_conditioning_tokenizer_name, + } + elif self.text_conditioning_tokenizer_name is None: + # If no text_conditioning_tokenizer_name is specified, use the first one as default + # For text context tokenization + self.text_conditioning_tokenizer_name = list(cfg.text_tokenizers.keys())[0] + + # TODO @xueyang: both tokenizers are only used to get some token ids. We + # should kill them to save a small amount of mem resources since dataloader will initialize them + # again after the worker processes are spawned. + self.tokenizer = setup_tokenizers( + all_tokenizers_config=cfg.text_tokenizers, + mode='train', + ) num_tokens_tokenizer = len(self.tokenizer.tokens) + if self.legacy_text_conditioning: + # Text context tokens are not a part of the the regular transcript embedding table in legacy models + num_tokens_tokenizer -= self.tokenizer.num_tokens_per_tokenizer[self.text_conditioning_tokenizer_name] + num_tokens = num_tokens_tokenizer + 2 # +2 for BOS and EOS self.bos_id = num_tokens - 2 self.eos_id = num_tokens - 1 - self.audio_bos_id = cfg.num_audio_tokens_per_codebook - 2 - self.audio_eos_id = cfg.num_audio_tokens_per_codebook - 1 - self.context_audio_bos_id = cfg.num_audio_tokens_per_codebook - 2 # For backward compatibility - self.context_audio_eos_id = cfg.num_audio_tokens_per_codebook - 1 # For backward compatibility - self.model_type = cfg.get('model_type', 'single_encoder_sv_tts') + self.model_type = cfg.get('model_type', None) + self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce'] + self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) - if self.model_type == 'decoder_context_tts': - self.context_audio_bos_id = ( - cfg.num_audio_tokens_per_codebook - 4 - ) # Changing these to make them different from target audio bos and eos - self.context_audio_eos_id = cfg.num_audio_tokens_per_codebook - 3 + # Below args (text_context_remapping_json, text_context_remapping_prob) are + # for combining multiple context_texts into a single one during training. + # Eg. 
if we want to treat Emma_neutral and Emma_conversational as one speaker, + # we can create an override dict {'Emma_neutral' : 'Emma', 'Emma_conversational' : 'Emma'} + # This dict is saved in a json file given by cfg.model.text_context_remapping_json + # If we want to preserve both behaviours i.e (Emma_neutral, Emma_conversational) and just (Emma) + # we can do this mapping with a probability during training, as specified by text_context_remapping_prob + self.text_context_remapping = None + text_context_remapping_json = cfg.get('text_context_remapping_json', None) + self.text_context_remapping_prob = cfg.get('text_context_remapping_prob', 0.0) + if text_context_remapping_json is not None: + with open(text_context_remapping_json, 'r') as f: + self.text_context_remapping = json.load(f) - self._tb_logger = None + super().__init__(cfg=cfg, trainer=trainer) - self.pad_context_text_to_max_duration = self.model_type == 'decoder_context_tts' - self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False) + if self.legacy_text_conditioning: + tc_tokenizer = self.tokenizer.tokenizers[self.text_conditioning_tokenizer_name] + self.context_text_embedding = nn.Embedding(tc_tokenizer.vocab_size, cfg.embedding_dim) - super().__init__(cfg=cfg, trainer=trainer) + # This needs to happen after super().__init__() + self._codec_model = codec_model + self._codec_model.freeze() # Lightning does requires_grad = False and self.eval() + self._codec_converter = codec_converter audio_embeddings = [] - for _ in range(cfg.num_audio_codebooks): - audio_embeddings.append(nn.Embedding(cfg.num_audio_tokens_per_codebook, cfg.embedding_dim)) + for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): + audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim)) self.audio_embeddings = nn.ModuleList(audio_embeddings) - if self.model_type != 'decoder_pretrain_synthesizer': - # Decoder pretrain synthesizer doesn't have transcript encoder/text embeddings + if self.use_bpe_char_tokenizer: + # BPE char tokenizer + assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer" + tokenizer_name = self.tokenizer.tokenizer_names[0] + tokenizer = self.tokenizer.tokenizers[tokenizer_name] + subword_vocab = tokenizer.get_vocab() + # special tokens will be stored as it is in the char_vocab + # Each special token will only be mapped to one char id + special_vocab = { + '': self.bos_id, + '': self.eos_id, + } + self.cas_encoder = CharAwareSubwordEncoder( + d_embed=cfg.embedding_dim, + llm_tokenizer_vocab=subword_vocab, + subword_padding_idx=self.tokenizer.pad, + special_vocab=special_vocab, + ) + else: + # Regular text embedding self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim) - self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) + self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) + self.final_proj = nn.Linear( + cfg.decoder.d_model, + self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor, + ) - self.final_proj = nn.Linear(cfg.decoder.d_model, cfg.num_audio_codebooks * cfg.num_audio_tokens_per_codebook) + self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower()) + logging.info(f"Local transformer type: {self.local_transformer_type}") + if self.local_transformer_type != LocalTransformerType.NO_LT: + local_transformer_hidden_dim = 
cfg.get('local_transformer_hidden_dim', 256) + if local_transformer_hidden_dim != cfg.decoder.d_model: + self.local_transformer_in_projection = nn.Linear(cfg.decoder.d_model, local_transformer_hidden_dim) + else: + self.local_transformer_in_projection = nn.Identity() + self.local_transformer = transformer_2501.Transformer( + n_layers=self.cfg.get('local_transformer_n_layers', 2), + d_model=local_transformer_hidden_dim, + d_ffn=local_transformer_hidden_dim * 4, + sa_n_heads=self.cfg.get('local_transformer_n_heads', 1), + kernel_size=1, + is_causal=self.local_transformer_type == LocalTransformerType.AR, + max_length_causal_mask=self.frame_stacking_factor * self.num_audio_codebooks + 2, + use_learnable_pos_emb=True, + ) + local_transformer_out_projections = [] + for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): + # Have a separate projection layer for each codebook, to distinguish between them + local_transformer_out_projections.append( + nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook) + ) + self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections) + + if cfg.get('use_alignment_encoder', False): + self.alignment_encoder = AlignmentEncoder( + n_mel_channels=cfg.embedding_dim, + n_text_channels=cfg.embedding_dim, + dist_type="cosine", + temperature=15.0, + ) - codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False) - # del codec discriminator to free memory - del codec_model.discriminator - codec_model.eval() - self.freeze_model(codec_model) - self._codec_model = codec_model + if self.model_type == 'multi_encoder_context_tts': + logging.warning(f"The multi_encoder_context_tts model type for {self} is deprecated.") - if self.model_type == 'single_encoder_sv_tts': - speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - model_name='titanet_large' - ) - speaker_verification_model.eval() - self.freeze_model(speaker_verification_model) - self._speaker_verification_model = speaker_verification_model - self.speaker_projection_layer = nn.Linear(cfg.speaker_emb_dim, cfg.embedding_dim) - self.transcript_decoder_layers = [ - idx for idx in range(cfg.decoder.n_layers) - ] # All layers are used for text - elif self.model_type == 'multi_encoder_context_tts': + # Transcript and context audio/text go to different encoders. 
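+            # (With the defaults below, for example, the transcript encoding is attended by decoder layers 3-8 and the
+            # context encoding by layers 0-2 and 9-11.)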
+ # Output of the encoders goes to the decoder through the cross-attention layers self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8]) self.context_decoder_layers = cfg.get( 'context_decoder_layers', [0, 1, 2, 9, 10, 11] ) # For backward compatibility - multi_encoder_mapping = [None for _ in range(cfg.decoder.n_layers)] + multi_encoder_mapping = [None for _ in range(self.decoder.n_layers)] for layer in self.transcript_decoder_layers: multi_encoder_mapping[layer] = 0 # 0 means text goes to this layer, 1 means context goes to this layer for layer in self.context_decoder_layers: @@ -196,27 +311,61 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.multi_encoder_mapping = multi_encoder_mapping self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder)) elif self.model_type == 'decoder_context_tts': + # Context audio/text goes directly to the decoder (before the target audio codes) + self.transcript_decoder_layers = [ + idx for idx in range(self.decoder.n_layers) + ] # All layers are used for text + elif self.model_type == 'decoder_ce': + # Similar to decoder_context_tts, but we use context encoder + # Decoder gets output from context encoder instead of raw context tokens embeddings + self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder)) self.transcript_decoder_layers = [ idx for idx in range(cfg.decoder.n_layers) ] # All layers are used for text - elif self.model_type == 'decoder_pretrain_synthesizer': - assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer" else: raise ValueError(f"Unsupported model type {self.model_type}") - if self.use_text_conditioning_encoder: - self.context_text_embedding = nn.Embedding(self.text_conditioning_tokenizer.vocab_size, cfg.embedding_dim) - self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') - alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0) - if alignment_loss_scale > 0.0: - self.alignment_loss = ForwardSumLoss(loss_scale=alignment_loss_scale) - - def freeze_model(self, model): - for param in model.parameters(): - param.requires_grad = False + self.alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0) + self.alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0) + if self.alignment_loss_scale > 0.0: + self.alignment_loss = ForwardSumLoss(loss_scale=self.alignment_loss_scale) + if self.alignment_encoder_loss_scale > 0.0: + self.alignment_encoder_loss = ForwardSumLoss(loss_scale=self.alignment_encoder_loss_scale) + + # Define cfg parameters into self parameters + self.prior_end_step = self.cfg.prior_end_step + self.prior_scaledown_start_step = self.cfg.prior_scaledown_start_step + self.indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0) + self.ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) + self.cfg_unconditional_prob = self.cfg.get('cfg_unconditional_prob', 0.0) + self.decoder_input_dropout_prob = self.cfg.get('decoder_input_dropout_prob', 0.0) + self.binarize_attn_method = self.cfg.get('binarize_attn_method', 'argmax') + self.binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2) + self.prior_future_decay = self.cfg.get('prior_future_decay', 1.0) + self.prior_past_decay = self.cfg.get('prior_past_decay', 1.0) + self.binarized_prior_epsilon = self.cfg.get('binarized_prior_epsilon', 0.0) + self.prior_future_context = self.cfg.get('prior_future_context', 1) + 
self.prior_past_context = self.cfg.get('prior_past_context', 1) + self.binarize_prior_after_step = self.cfg.get('binarize_prior_after_step', 0) + self.codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0) + self.local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0) + self.use_alignment_encoder = self.cfg.get('use_alignment_encoder', False) + self.use_prior_for_aligner = self.cfg.get('use_prior_for_aligner', False) + self.aligner_encoder_train_steps = self.cfg.get('aligner_encoder_train_steps', float('inf')) + self.dec_random_input_max = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook) + + # Configuration validity checks + self.check_frame_stacking_config_validity() def state_dict(self, destination=None, prefix='', keep_vars=False): + """ + Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model + from the checkpoint. The codec model is saved in a separate checkpoint. + + _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts + model_type that is no longer supported and can likely be removed in a future version. + """ if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} # Don't save the speaker verification and codec model in the state dict @@ -227,28 +376,77 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): del state_dict[key] return state_dict - def load_state_dict(self, state_dict, strict=True): - # Override to load all the keys except _speaker_verification_model and _codec_model - super().load_state_dict(state_dict, strict=False) + def check_frame_stacking_config_validity(self): + """ + Check if the configuration is compatible with frame stacking. + """ + if self.frame_stacking_factor > 1: + # The settings below are not supported with frame stacking. + # Some of them may work - but they have not been tested. + + # disallow alignment encoder + if self.use_alignment_encoder: + raise ValueError("Alignment encoder is not supported for frame stacking") + # disallow alignment loss + if self.alignment_loss_scale > 0.0: + raise ValueError("Alignment loss is not supported for frame stacking") + # disallow training prior + if self.cfg.prior_scaling_factor is not None and self.cfg.prior_scaling_factor > 0: + raise ValueError("Training-time attention prior is not supported for frame stacking") + # disallow text conditioning + if self.use_text_conditioning_encoder: + raise ValueError("Text conditioning is not supported for frame stacking") - def _setup_tokenizers(self, cfg, mode='test'): - tokenizer, text_conditioning_tokenizer = setup_tokenizers( - cfg.text_tokenizers, cfg.use_text_conditioning_encoder, mode=mode - ) - return tokenizer, text_conditioning_tokenizer + def update_ckpt(self, state_dict): + """ + Backward compatibility for checkpoints saved with old model names. 
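+        For example, parameter keys saved under 't5_encoder.*' / 't5_decoder.*' are renamed to
+        'encoder.*' / 'decoder.*' to match the current module names.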
+ """ + new_state_dict = {} + for key in state_dict.keys(): + if 't5_encoder' in key: + new_key = key.replace('t5_encoder', 'encoder') + new_state_dict[new_key] = state_dict[key] + elif 't5_decoder' in key: + new_key = key.replace('t5_decoder', 'decoder') + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[key] = state_dict[key] + return new_state_dict - @property - def tb_logger(self): - if self._tb_logger is None: - if self.logger is None and self.logger.experiment is None: - return None - tb_logger = self.logger.experiment - for logger in self.trainer.loggers: - if isinstance(logger, TensorBoardLogger): - tb_logger = logger.experiment - break - self._tb_logger = tb_logger - return self._tb_logger + def load_state_dict(self, state_dict, strict=True): + """ + Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when + strict is True. + When strict is False, we can call pytorch's load_state_dict. + When strict is True, we loop through all parameters and rename them to enable loading. + + _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts + model_type that is no longer supported and can likely be removed in a future version. + """ + state_dict = self.update_ckpt(state_dict) + if strict == False: + super().load_state_dict(state_dict, strict=False) + for name, child in self.named_children(): + if name in [ + '_speaker_verification_model', + '_codec_model', + '_reference_model', + 'eval_asr_model', + 'eval_speaker_verification_model', + 'whisper_model', + 'squim_objective_model', + ]: + continue + if any(param.numel() > 0 for param in child.parameters()): + # If the module has parameters, we want to change the default mapping so that the state_dict gets + # loaded. + # Ex: state_dict[encoder.position_embeddings.weight] -> new_state_dict[position_embeddings.weight] + new_state_dict = {} + for key in state_dict.keys(): + name_with_dot = f"{name}." + if key.startswith(name_with_dot): + new_state_dict[key[len(name_with_dot) :]] = state_dict[key] + child.load_state_dict(new_state_dict) def audio_to_codes(self, audio, audio_len, audio_type='target'): # audio: (B, T) @@ -263,12 +461,16 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): raise ValueError(f"Received audio_type of {audio_type}. Must be `target` or `context`") self._codec_model.eval() - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32): codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) + if self._codec_converter is not None: + codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len) # Add a timestep to begining and end of codes tensor bos_tensor = torch.full( (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device ) + # pad at the end to make room for the EOS token; the EOS token's actual position + # varies per batch element depending on each element's length. 
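+        # Illustrative example (hypothetical lengths): for an element with codes_len == 4, the concatenation below
+        # yields [BOS, c0, c1, c2, c3, PAD]; the loop then writes EOS at index codes_len + 1, giving
+        # [BOS, c0, c1, c2, c3, EOS], and codes_len is incremented by 2.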
pad_tensor = torch.full( (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device ) # 0 is the padding token in the audio codebook @@ -277,69 +479,200 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'): # codes_len: (B,) for idx in range(codes.size(0)): codes[idx, :, codes_len[idx] + 1] = audio_eos_id - codes_len = codes_len + 2 - + codes_len = codes_len + 2 # +1 for bos and +1 for eos return codes.long(), codes_len.long() def codes_to_audio(self, codes, codes_len): # codes: (B, C, T') # codes_len: (B,) self._codec_model.eval() - with torch.no_grad(): - # Replace eos and bos tokens with padding in codes tensor - codes[codes == self.audio_bos_id] = 0 # zero is the padding token in the audio codebook - codes[codes == self.audio_eos_id] = 0 - # self.additional_models['codec'] = self.additional_models['codec'].to(codes.device) - audio, audio_len = self._codec_model.decode(tokens=codes, tokens_len=codes_len) + with torch.no_grad(), torch.autocast(device_type=codes.device.type, dtype=torch.float32): + # Make a copy to avoid modifying the original tensor if it's used elsewhere + codes_copy = codes.clone() + # Replace eos and bos tokens with padding in the copied tensor + codes_copy[codes == self.audio_bos_id] = 0 # zero is the padding token + codes_copy[codes == self.audio_eos_id] = 0 + # Pass the modified integer token IDs + if self._codec_converter is not None: + codes_copy = self._codec_converter.convert_new_to_original( + audio_tokens=codes_copy, audio_lens=codes_len + ) + audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len) # audio: (B, T) # audio_len: (B,) return audio, audio_len def embed_audio_tokens(self, audio_tokens): - # audio_tokens: (B, C, T') - # Add and average the embeddings of the audio tokens across the codebooks + B, C, T = audio_tokens.shape audio_embedding = None - for c in range(audio_tokens.size(1)): - embedding = self.audio_embeddings[c](audio_tokens[:, c, :]) - if audio_embedding is None: - audio_embedding = embedding - else: - audio_embedding = audio_embedding + embedding - audio_embedding = audio_embedding / audio_tokens.size(1) + for i in range(self.frame_stacking_factor): + for c in range(C): + tokens = audio_tokens[:, c, i :: self.frame_stacking_factor] + embedding = self.audio_embeddings[c + i * C](tokens) + if audio_embedding is None: + audio_embedding = embedding + else: + audio_embedding += embedding + audio_embedding = audio_embedding / (C * self.frame_stacking_factor) return audio_embedding - def get_speaker_embeddings(self, audio_16khz, audio_len_16khz): - # audio_16khz: (B, T) - # audio_len_16khz: (B,) - self._speaker_verification_model.eval() - with torch.no_grad(): - _, speaker_embeddings = self._speaker_verification_model.forward( - input_signal=audio_16khz, input_signal_length=audio_len_16khz - ) - return speaker_embeddings - - def compute_loss(self, logits, audio_codes, audio_codes_lens): - # logits: (B, T', num_codebooks * num_tokens_per_codebook) - # audio_codes: (B, C, T') - # audio_codes_lens: (B,) - loss_mask = get_mask_from_lengths(audio_codes_lens) + def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_offset_by_one=False): + """ + Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes. + This function is used in training and validation, not inference/sampling. 
+ The sequence layout is slightly different between AR and MG modes, as shown in the diagram below, + (using an 8-codebook setup as an example): + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | AR target | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | none | + | codebook | | | | | | | | | | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | MG target | none | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + | codebook | | | | | | | | | | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | input | Magpie | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + | codebook | latent | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + | seq. index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | + +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ + + dec_out: (B, T', E) + audio_codes_target: (B, C, T') + targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive) + if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit) + """ + C = self.num_audio_codebooks + dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E) + local_transformer_input = [dec_out_all] + # Build the teacher-forced input to the LT. + for fs_index in range(self.frame_stacking_factor): + for codebook_num in range(C): + # Collect ground truth codes for the current codebook and frame stack index combintation. + codes = audio_codes_target[:, codebook_num, fs_index :: self.frame_stacking_factor] # (B, T') + # Individual timesteps are independently handled by the LT fold time into the batch dimension. + codes = codes.reshape(-1) # (B*T',) + # Embed the codes + codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes) # (B*T', E) + local_transformer_input.append(codebook_embedding) + # Stack the input codes along dimension 1 (codebooks). This is the dimension along which the LT predicts iteratively. + local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E) + local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128) + _mask = torch.ones( + local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device + ) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E) + if not targets_offset_by_one: + # for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc. + local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E) + else: + # for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc. 
+ local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E) + all_code_logits = [] + for fs_index in range(self.frame_stacking_factor): + for codebook_num in range(audio_codes_target.size(1)): + # Using a separate projection layer for each codebook (to distinguish between them) + # Checked the time - this loop is not taking much time (compared to the local transformer forward pass) + codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index * C]( + local_transformer_output[:, codebook_num + fs_index * C, :] + ) # (B*T', num_all_tokens_per_codebook) + all_code_logits.append(codebook_logits) + all_code_logits = torch.cat( + all_code_logits, dim=1 + ) # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor) + + all_code_logits = all_code_logits.view( + audio_codes_target.size(0), audio_codes_target.size(2) // self.frame_stacking_factor, -1 + ) # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor) + + return all_code_logits + + def maskgit_create_random_mask(self, codes): + """ + Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN. + """ + # Codes: (B, C, T) + B, C, T = codes.shape + # get a uniform random vector uniformly sampled from [0,1) ## Todo does it need to be inclusive on the right? + rand_values = torch.rand(B, T, device=codes.device) + # apply the cosine schedule + frac_masked = cosine_schedule(rand_values) + # how many positions to mask + n_masked = torch.ceil(frac_masked * C).long() # B,T + # The code further below is the vectorized version of this: + # for b in range(B): + # for t in range(T): + # if n_masked[b,t] > 0: + # # get a random permutation of the codebook indices + # perm = torch.randperm(C) + # # mask the top n_masked positions + # mask[b, perm[:n_masked[b,t]], t] = True + # + # Create random permutations + random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T) + # Create a mask tensor where each position indicates if it should be masked + mask_indices = torch.arange(C, device=codes.device).view(1, C, 1) + mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T) + # Apply the random permutations to the mask + mask = torch.gather(mask, 1, random_permutations) + + return mask # (B, C, T) + + def maskgit_apply_random_mask(self, codes): + # Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule. + # Codes: (B, C, T) + mask = self.maskgit_create_random_mask(codes) + # replace some tokens with MASK_TOKEN + codes_with_mask = torch.where(mask, self.mask_token_id, codes) + return codes_with_mask, mask + + def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None, frame_stacking_factor=1): + """ + Computes the audio codebook loss. Used by + (1) The main Magpie-TTS transformer + (2) The local transformer, for both autoregressive and MaskGit methods + + logits: (B, T', num_codebooks * num_tokens_per_codebook) + audio_codes: (B, C, T') + audio_codes_lens: (B,) + mask_tokens_mask: (B, C, T') True for tokens that were replaced with the MASK_TOKEN and should + therefore be the only ones included in the loss computation (for MaskGit). + frame_stacking_factor: int, the stacking factor used in the model + """ + loss_mask = get_mask_from_lengths(audio_codes_lens, pad_to_factor=frame_stacking_factor) + if mask_tokens_mask is not None: + # For MaskGit we only compute loss for the masked tokens. 
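The commented-out double loop in `maskgit_create_random_mask` above has a compact vectorized equivalent; the following self-contained sketch shows the same idea (cosine schedule, per-timestep mask counts, random codebook choice). The `cosine_schedule` below is a stand-in with an assumed form, since the model's helper is defined elsewhere, and the MASK id is illustrative.

```python
import math
import torch

def cosine_schedule(x: torch.Tensor) -> torch.Tensor:
    # Assumed form: maps uniform values in [0, 1) to a masking fraction.
    return torch.cos(x * math.pi / 2)

B, C, T = 2, 8, 5
codes = torch.randint(0, 2016, (B, C, T))
frac_masked = cosine_schedule(torch.rand(B, T))          # fraction per timestep
n_masked = torch.ceil(frac_masked * C).long()            # (B, T) counts

# Randomly choose which n_masked codebooks to mask at each (b, t).
perm = torch.argsort(torch.rand(B, C, T), dim=1)         # random permutations
mask = torch.arange(C).view(1, C, 1) < n_masked.view(B, 1, T)
mask = torch.gather(mask, 1, perm)                       # (B, C, T) boolean

MASK_TOKEN = 2020                                        # illustrative MASK id
codes_with_mask = torch.where(mask, MASK_TOKEN, codes)
```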
+ # *Both* conditions must be true: + # 1. the token is masked + # 2. the token is not padding + loss_mask = loss_mask.unsqueeze(1) * mask_tokens_mask + if not loss_mask.any(): + # Without this we were very rarely getting NaNs in the loss + logging.warning("No tokens valid were found in compute_loss()!") + return torch.tensor(0.0, device=loss_mask.device), loss_mask + else: + # repeat loss mask for each codebook to simplify code below + loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1) total_codebook_loss = None - for codebook in range(audio_codes.size(1)): - si = codebook * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook - codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) - codebook_targets = audio_codes[:, codebook] # (B, T') - codebook_loss = self.cross_entropy_loss( - codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T') - ) # (B, T') - codebook_loss = codebook_loss * loss_mask - codebook_loss = codebook_loss.sum() / loss_mask.sum() - if total_codebook_loss is None: - total_codebook_loss = codebook_loss - else: - total_codebook_loss = total_codebook_loss + codebook_loss + for fs_index in range(frame_stacking_factor): + for codebook in range(audio_codes.size(1)): + si = (codebook + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) + codebook_targets = audio_codes[:, codebook, fs_index::frame_stacking_factor] # (B, T') + codebook_loss = self.cross_entropy_loss( + codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T') + ) # (B, T') + codebook_loss_mask = loss_mask[:, codebook, fs_index::frame_stacking_factor] + codebook_loss = codebook_loss * codebook_loss_mask + if codebook_loss_mask.sum() == 0: + logging.warning(f"Loss mask for codebook {codebook} is all zeros, global_step: {self.global_step}") + continue + codebook_loss = codebook_loss.sum() / codebook_loss_mask.sum() + if total_codebook_loss is None: + total_codebook_loss = codebook_loss + else: + total_codebook_loss = total_codebook_loss + codebook_loss - total_codebook_loss = total_codebook_loss / audio_codes.size(1) + total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor) return total_codebook_loss, loss_mask def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prior, multi_encoder_mapping): @@ -353,66 +686,542 @@ def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prio ) attn_probabilities = decoder_out['attn_probabilities'] all_code_logits = self.final_proj(decoder_out['output']) # (B, T', num_codebooks * num_tokens_per_codebook) - return all_code_logits, attn_probabilities + return all_code_logits, attn_probabilities, decoder_out['output'] def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): # all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook) # audio_codes_lens: (B,) - all_preds = [] - for idx in range(self.cfg.num_audio_codebooks): - si = idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook - codebook_logits = all_code_logits[:, :, si:ei] - codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook) - # argmax to get the tokens - codebook_preds = torch.argmax(codebook_probs, dim=-1) # (B, T') - all_preds.append(codebook_preds) - - all_preds = 
torch.stack(all_preds, dim=1) # (B, C, T') + all_preds = [[] for _ in range(self.frame_stacking_factor)] + for fs_index in range(self.frame_stacking_factor): + for idx in range(self.num_audio_codebooks): + si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + codebook_logits = all_code_logits[:, :, si:ei] + codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook) + # argmax to get the tokens + codebook_preds = torch.argmax(codebook_probs, dim=-1) # (B, T') + all_preds[fs_index].append(codebook_preds) + all_preds = [ + torch.stack(p, dim=1) for p in all_preds + ] # list of `frame_stacking_factor`` elements of shape (B,C,T) each + all_preds = torch.stack(all_preds, dim=-1) # B, C, T, frame_stacking_factor + # undo the frame stacking + all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1) # B, C, T*frame_stacking_factor + pred_max_len = all_preds.size(2) + real_max_len = audio_codes_lens.max() + assert (pred_max_len - real_max_len) < self.frame_stacking_factor + # trim padding introduced for frame stacking + all_preds = all_preds[:, :, :real_max_len] audio_mask = get_mask_from_lengths(audio_codes_lens) all_preds = all_preds * audio_mask.unsqueeze(1) return all_preds - def sample_codes_from_logits(self, all_code_logits_t, temperature=0.7, topk=80): - # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep + def visualize_codes(self, codes, mask_id=2020, frame_stacking_rate=2): + """ + Visualize codes for analysis purposes + codes: (B, C) + """ + + def code_to_str(code): + if code == mask_id: + return "M " + else: + return f"{code:04d} " + + B, C = codes.shape + if B > 1: + logging.debug("Warning: visualizing only first batch element") + codes = codes.clone().detach().cpu().numpy()[0] + codes = [code_to_str(c) for c in codes] + output_str = "" + for i, c in enumerate(codes): + if (i) % (C / frame_stacking_rate) == 0: + output_str += "|timestep| " + output_str += c + logging.debug(output_str) + + def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor: + """ + Sets logits of forbidden tokens to `-inf` so they will never be sampled. + Specifically, we forbid sampling of all special tokens except AUDIO_EOS + which is allowed by default. + Args: + logits: (B, C, num_audio_tokens_per_codebook) + forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens + from being sampled. Default: False. + """ + logits[ + :, + :, + SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos), + ] = float('-inf') + return logits + + def local_transformer_sample_maskgit( + self, + dec_output: torch.Tensor, + temperature: float = 0.7, + topk: int = 80, + unfinished_items: Dict[int, bool] = {}, + finished_items: Dict[int, bool] = {}, + use_cfg: bool = False, + cfg_scale: float = 1.0, + n_steps: int = 3, + noise_scale: float = 0.0, + fixed_schedule: Optional[List[int]] = None, + dynamic_cfg_scale: bool = False, + sampling_type: Optional[str] = None, + forbid_audio_eos: bool = False, + ) -> torch.Tensor: + """ + Sample audio codes for the current timestep using MaskGit-like iterative + prediction with the local transformer. If frame-stacking is enabled, the + codes for all frames in the stack are sampled, treated as one long sequence. 
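To make the flattened logit layout used in `logits_to_audio_codes` above concrete: the decoder output packs `num_codebooks * frame_stacking_factor` vocabularies side by side, each slice is argmaxed independently, and the per-stack predictions are then interleaved back into consecutive frames. A small sketch with assumed toy sizes:

```python
import torch

B, T, C, fs, V = 2, 4, 3, 2, 10          # toy sizes; V = tokens per codebook
logits = torch.randn(B, T, C * V * fs)   # flattened per-(codebook, stack-index) layout

preds_per_stack = []
for i in range(fs):
    per_codebook = []
    for c in range(C):
        si = (c + C * i) * V
        per_codebook.append(logits[:, :, si : si + V].argmax(dim=-1))  # (B, T)
    preds_per_stack.append(torch.stack(per_codebook, dim=1))           # (B, C, T)

# Interleave the fs predictions back into consecutive frames: (B, C, T * fs).
preds = torch.stack(preds_per_stack, dim=-1).reshape(B, C, -1)
print(preds.shape)  # torch.Size([2, 3, 8])
```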
+ + The MaskGit process starts with all positions masked and iteratively unmasks the + most confident positions over multiple steps. By "masked" we mean that a + dedicated MASK token is used (as opposed to attention masking). The LT in this + case is a non-causal transformer decoder. At each step the model predicts all + positions at once. Of those predictions, a subset of the most confident + previously-masked positions is kept and unmasked in the next step. The number of + positions that are unmasked at each step is determined by the unmasking + schedule. We support a cosine schedule and a fixed schedule provided by the + user. + + Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG). + + Special handling: + * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled + * forces / forbids EOS for finished / unfinished items respectively + * optionally, globally forbids audio EOS for all items in the batch. + This is useful early in the generation process. + * supports different unmasking methods, see `sampling_type` argument for details. + + Args: + dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size + and E is primary decoder's embedding dimension. + temperature (float, optional): Sampling temperature + topk (int, optional): Number of top-probability tokens to consider in sampling. + unfinished_items (dict, optional): Dictionary containing indices of batch + items that we are confident have not completed generation. For these items, audio EOS + sampling is forbidden. + finished_items (dict, optional): Dictionary containing indices of batch + items that we are confident are completed. For these items, audio EOS sampling + is forced. + use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size + to be doubled with conditional and unconditional outputs from the primary decoder. + cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True. + n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling. + noise_scale (float, optional): Scale factor for noise to add to confidence scores + during sampling (experimental). + fixed_schedule (list, optional): Fixed schedule for number of tokens to unmask at each step. + If None, uses cosine schedule. + dynamic_cfg_scale (bool, optional): Whether to dynamically adjust CFG scale during + sampling (experimental). + sampling_type (str, optional): Type of sampling strategy. Options are: + ["default", "causal", "purity_causal", "purity_default"]. + * Purity refers to "purity sampling" from https://arxiv.org/abs/2304.01515. If "purity" + is not specified, confidence sampling is used as in the original MaskGit paper. + * "default"/"causal": Controls the order of unmasking across frames when frame-stacking is enabled. + If "causal" is specified, frames are unmasked in causal order. "default" + doesn't impose any constraints on the unmasking order. + forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire + batch. 
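For intuition on the unmasking schedule described above: at step `s` of `n_steps`, the fraction still masked follows the cosine schedule of the progress, so more positions get revealed per step toward the end. A small illustration with assumed values (the real schedule lives in the model's `cosine_schedule` helper and may differ; a `fixed_schedule` overrides it entirely):

```python
import math

def cosine_schedule(progress: float) -> float:
    # Assumed form: 1.0 at progress 0, decaying to 0.0 at progress 1.
    return math.cos(progress * math.pi / 2)

codebook_seq_len, n_steps = 16, 4          # e.g. 8 codebooks * frame stack of 2
for step in range(n_steps):
    frac_masked = cosine_schedule(step / n_steps)
    n_masked = math.ceil(codebook_seq_len * frac_masked)
    n_unmasked = codebook_seq_len - n_masked
    print(step, n_unmasked)                # n_unmasked over steps: 0, 1, 4, 9
```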
+ + Returns: + torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor) + """ + # dec_output: (B, E) + device = dec_output.device + # disable KV cache since our transformer is not causal + self.local_transformer.reset_cache(use_cache=False) + dec_output = dec_output.unsqueeze(1) # (B, 1, E) + local_transformer_input_init = self.local_transformer_in_projection( + dec_output + ) # (B, 1, D) where D is the dimension of the local transformer + codebook_seq_len = self.num_audio_codebooks * self.frame_stacking_factor + B = dec_output.size(0) + + min_confidence = 0 + # this needs to be large enough that unmasked items will always remain unmasked (even after noise addition) + # Setting it smaller could allow "regret", i.e. re-masking a codebook that was previously unmasked; we might want to try that + max_confidence = 5 + confidences = min_confidence * torch.ones(B, codebook_seq_len, device=device) + # initialize to all masked + codes = self.mask_token_id * torch.ones((B, codebook_seq_len), device=device, dtype=torch.long) + sampled_codes = codes.clone() + topk_indices = None + if fixed_schedule is not None: + n_steps = len(fixed_schedule) + for step in range(n_steps): + # how far along we are in the unmasking process + progress = step / n_steps + # get mask fraction + frac_masked = cosine_schedule(torch.tensor(progress)) + if sampling_type == "causal" or sampling_type == "purity_causal": + frac_masked = torch.ones_like(frac_masked) * (1.0 - progress) + # how many codebooks to mask + if fixed_schedule is None: + n_masked = torch.ceil(codebook_seq_len * frac_masked).long() + else: + n_masked = codebook_seq_len - fixed_schedule[step] + n_unmasked = codebook_seq_len - n_masked + + if ( + sampling_type == "causal" or sampling_type == "purity_causal" + ): # and n_unmasked <= self.num_audio_codebooks: + # force second frame not to be unmasked + n_frames_to_allow = int(np.floor(progress * self.frame_stacking_factor + 1)) + confidences[:, n_frames_to_allow * self.num_audio_codebooks :] = ( + min_confidence - 1 + ) # only tested for frame_stacking_factor=2 + + # pick top-confidence codebooks up to n_unmasked + _, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1) + if use_cfg: + actual_batch_size = topk_indices.size(0) // 2 + assert ( + topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size] + ).all(), "Topk indices are not the same for conditional and unconditional codes" + + # replace masks of the top-k confident codebooks with the codes that were sampled for them + unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices) + codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) + + # build transformer input + local_transformer_input = local_transformer_input_init + for codebook_num in range(codebook_seq_len): + next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze( + 1 + ) # (B, 1, 768) + next_local_transformer_input = self.local_transformer_in_projection( + next_local_transformer_input + ) # (B, 1, d_local) + local_transformer_input = torch.cat( + [local_transformer_input, next_local_transformer_input], dim=1 + ) # (B, codebook_num+1, d_local) + + # run transformer + _mask = torch.ones(B, codebook_seq_len + 1, device=device) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)[ + 'output' + ] # (B, C+1, d_local) + + # get logits + logits = [] + for codebook_num in range(codebook_seq_len): + # The `codebook_num+1` is to drop first position which 
corresponds to the magpie latent + codebook_logits = self.local_transformer_out_projections[codebook_num]( + local_transformer_output[:, codebook_num + 1, :] + ) # (B, num_audio_tokens_per_codebook) + logits.append(codebook_logits) + logits = torch.stack(logits, dim=1) # (B, C*frame_stacking_factor, num_audio_tokens_per_codebook) + + # apply CFG + if use_cfg: + actual_batch_size = logits.size(0) // 2 + conditional_logits = logits[:actual_batch_size] + unconditional_logits = logits[actual_batch_size:] + if not dynamic_cfg_scale: + current_cfg_scale = cfg_scale + else: + # gradually increase the scale until mid point through sampling, then reduce it again + progress = step / (n_steps - 1) + # interp = -abs(progress-0.5)+0.5 # increase from 0..1 in the interval from start to midpoint and then go back to zero + # interp = 1.0 - progress # decrease from 1 to 0 + interp = progress # gradually increase from 0 to 1 + current_cfg_scale = (cfg_scale - 1) * interp + 1.0 # 1.0 --> cfg_scale --> 1.0 + cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits + logits[:actual_batch_size] = cfg_logits + + # Disallow generation of special tokens + logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos) + + # handle unfinished and finished items + for item_idx in unfinished_items: + logits[item_idx, self.audio_eos_id] = float('-inf') + for item_idx in finished_items: + logits[item_idx, :, :] = float('-inf') + logits[item_idx, :, self.audio_eos_id] = 0.0 + + # sample with top-k + logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk) + indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook) + logits_rescored = logits.clone() + logits_rescored[indices_to_remove] = float('-inf') + probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook) + sampled_codes = torch.multinomial(probs.view(B * codebook_seq_len, -1), 1).view(B, codebook_seq_len) + if use_cfg: + sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size] + probs[actual_batch_size:] = probs[:actual_batch_size] + if sampling_type != "purity_causal" and sampling_type != "purity_default": + confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1) + else: + # use the max probability across all tokens for each codebook as the confidence for each codebook; known as "purity sampling" + confidences = probs.max(dim=2)[0] + # replace entries in sampled_codes with previously unmasked codebooks + sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) + # add noise to confidences (as in token-critic paper, https://arxiv.org/abs/2209.04439) + if noise_scale > 0.0: + # get noise from uniform distribution in the interval [-0.5, 0.5), scale it by `noise_scale`, + # and anneal it to 0 as we approach the end of the unmasking process + noise = ( + (torch.rand_like(confidences) - 0.5) * noise_scale * (1 - (step + 2) / n_steps) + ) # the +2 makes sure that by the last iteration the noise is exactly 0 + confidences += noise + # the conditional and unconditional get different noise and must be fixed to be the same again + confidences[actual_batch_size:] = confidences[:actual_batch_size] + confidence_eps = 0.1 + assert ( + confidences.max() + confidence_eps < max_confidence + ), f"Predicted confidence is approaching max_confidence: {confidences.max()}" + # for unmasked codebooks, set confidence to max so that they will remain unmasked + 
confidences.scatter_( + index=topk_indices, dim=1, src=max_confidence * torch.ones_like(topk_indices, dtype=torch.float) + ) + codes = sampled_codes + assert not ( + codes == self.mask_token_id + ).any(), "Codes contain mask tokens after completion of MaskGit sampling" + + # break stacked groups of frames into individual frames + codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute( + 0, 2, 1 + ) # B, C, frame_stacking_factor + + if use_cfg: + # drop unconditional codes + codes = codes[:actual_batch_size] + return codes + + def local_transformer_sample_autoregressive( + self, + dec_output: torch.Tensor, + temperature: float = 0.7, + topk: int = 80, + unfinished_items: Dict[int, bool] = {}, + finished_items: Dict[int, bool] = {}, + use_cfg: bool = False, + cfg_scale: float = 1.0, + use_kv_cache: bool = True, + forbid_audio_eos: bool = False, + ) -> torch.Tensor: + """ + Sample audio codes autoregressively across codebooks using the local + transformer. Uses multinomial sampling with temperature, top-k, and + classifier-free guidance (CFG). + + The sequence is initialized with the primary decoder's hidden output as the only + input and is gradually extended a code for one codebook at a time, appending the + sampled code as input sequence for the next step. At the last step the sequence + is `num_codebooks` long. If frame stacking is enabled, codes for all frames in + the stack are sampled as one long sequence and the final sequence length is + `num_codebooks * frame_stacking_factor` codes long. + + Special handling: + * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled + * forces / forbids EOS for finished / unfinished items respectively + * optionally, globally forbids audio EOS (useful early in the generation process) + + Args: + dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size + and E is primary decoder's embedding dimension. + temperature (float, optional): Sampling temperature. + topk (int, optional): Number of top-probability tokens to consider in sampling. + unfinished_items (dict, optional): Dictionary containing indices of batch + items that we are confident have not completed generation. For these items, audio EOS + sampling is forbidden. + finished_items (dict, optional): Dictionary containing indices of batch + items that we are confident are completed. For these items, audio EOS sampling + is forced. + use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size + to be doubled with conditional and unconditional outputs from the primary decoder. + cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True. + use_kv_cache (bool, optional): Whether to use key-value caching in the transformer. + forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire + batch. + + Returns: + torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor) + where B is batch size (or actual_batch_size if use_cfg=True). 
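Both local-transformer samplers apply classifier-free guidance the same way: the batch carries a conditional and an unconditional half, and their logits are blended before sampling. A compact sketch of just that blend, with toy tensor sizes rather than the model's API:

```python
import torch

cfg_scale = 2.5
logits = torch.randn(6, 2024)            # doubled batch: 3 conditional + 3 unconditional
actual_batch_size = logits.size(0) // 2
conditional = logits[:actual_batch_size]
unconditional = logits[actual_batch_size:]

# cfg_scale = 1.0 reduces to the conditional logits; larger values push the
# prediction further away from the unconditional one.
cfg_logits = cfg_scale * conditional + (1.0 - cfg_scale) * unconditional
logits[:actual_batch_size] = cfg_logits  # sampling then uses the first half
```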
+ """ + + self.local_transformer.reset_cache(use_cache=use_kv_cache) + dec_output = dec_output.unsqueeze(1) # (B, 1, E) + local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) all_preds = [] - for idx in range(self.cfg.num_audio_codebooks): - si = idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook - codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook) + for codebook_num in range(self.num_audio_codebooks * self.frame_stacking_factor): + _mask = torch.ones( + local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device + ) + local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128) + codebook_logits = self.local_transformer_out_projections[codebook_num]( + local_transformer_output[:, -1, :] + ) # (B, num_all_tokens_per_codebook) + if use_cfg: + actual_batch_size = codebook_logits.size(0) // 2 + conditional_logits = codebook_logits[:actual_batch_size] + unconditional_logits = codebook_logits[actual_batch_size:] + cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits + codebook_logits[:actual_batch_size] = cfg_logits + + for item_idx in unfinished_items: + codebook_logits[item_idx, self.audio_eos_id] = float('-inf') + for item_idx in finished_items: + codebook_logits[item_idx, :] = float('-inf') + codebook_logits[item_idx, self.audio_eos_id] = 0.0 + + # Disallow generation of special tokens + codebook_logits = self.clear_forbidden_logits( + codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos + ).squeeze(1) + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 ) # (B, num_tokens_per_codebook) codebook_logits_rescored = codebook_logits.clone() codebook_logits_rescored[indices_to_remove] = float('-inf') - - codebook_probs = torch.softmax(codebook_logits / temperature, dim=-1) # (B, num_tokens_per_codebook) + codebook_probs = torch.softmax( + codebook_logits_rescored / temperature, dim=-1 + ) # (B, num_tokens_per_codebook) codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) + if use_cfg: + codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size] all_preds.append(codebook_preds) - all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks) + next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze( + 1 + ) # (B, 1, 128) + next_local_transformer_input = self.local_transformer_in_projection( + next_local_transformer_input + ) # (B, 1, 128) + local_transformer_input = torch.cat( + [local_transformer_input, next_local_transformer_input], dim=1 + ) # (B, T+1, 128) + + all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks * frame_stacking_factor) + all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute( + 0, 2, 1 + ) # (B, num_codebooks, frame_stacking_factor) + if use_cfg: + all_preds = all_preds[:actual_batch_size] + + return all_preds + + def sample_codes_from_logits( + self, + all_code_logits_t: torch.Tensor, + temperature: float = 0.7, + topk: int = 80, + unfinished_items: Dict[int, bool] = {}, + finished_items: Dict[int, bool] = {}, + forbid_audio_eos: bool = False, + ) -> torch.Tensor: + """ + Sample codes for all codebooks at a given timestep. Uses multinomial sampling + with temperature and top-k. 
If frame stacking is on (i.e. `frame_stacking_factor + > 1`), this function will sample across the entire frame stack. + + Special handling: + * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled + * forces / forbids EOS for finished / unfinished items respectively + * optionally, globally forbids audio EOS (useful early in the generation process) + + Args: + all_code_logits_t (torch.Tensor): Logits at a given timestep with shape + (B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor) + temperature (float, optional): Sampling temperature + topk (int, optional): Number of top-probability tokens to consider in sampling. + unfinished_items (dict, optional): Dictionary containing indices of batch + items that we are confident have not completed generation. For these items, audio EOS + sampling is forbidden. + finished_items (dict, optional): Dictionary containing indices of batch + items that we are confident are completed. For these items, audio EOS sampling + is forced. + forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire + batch. + + Returns: + torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor). + """ + all_preds = [[] for _ in range(self.frame_stacking_factor)] + for fs_index in range(self.frame_stacking_factor): + for idx in range(self.num_audio_codebooks): + si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook) + + for item_idx in unfinished_items: + codebook_logits[item_idx, self.audio_eos_id] = float('-inf') + for item_idx in finished_items: + codebook_logits[item_idx, :] = float('-inf') + codebook_logits[item_idx, self.audio_eos_id] = 0.0 + + # Disallow generation of special tokens + codebook_logits = self.clear_forbidden_logits( + codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos + ).squeeze(1) + + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) + indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( + -1 + ) # (B, num_tokens_per_codebook) + codebook_logits_rescored = codebook_logits.clone() + codebook_logits_rescored[indices_to_remove] = float('-inf') + + codebook_probs = torch.softmax( + codebook_logits_rescored / temperature, dim=-1 + ) # (B, num_tokens_per_codebook) + codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) + all_preds[fs_index].append(codebook_preds) + + all_preds = [ + torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds + ] # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks) + all_preds = torch.stack(all_preds, dim=2) # (B, num_codebooks, frame_stacking_factor) return all_preds def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0): # attention_prob_matrix List of (B, C, audio_timesteps, text_timesteps) + wandb_images_log = {} + with torch.no_grad(): attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1) # (B, C, audio_timesteps, text_timesteps) attention_prob_matrix_mean = attention_prob_matrix.mean(dim=1) # (B, audio_timesteps, text_timesteps) - for idx in range(min(3, attention_prob_matrix_mean.size(0))): - item_attn_matrix = attention_prob_matrix_mean[idx][ - dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx] - ] - item_attn_matrix = item_attn_matrix.detach().cpu().numpy() - attn_np = 
plot_alignment_to_numpy(item_attn_matrix.T) - self.tb_logger.add_image( - f'{prefix}attention_matrix_{idx}', - attn_np, - global_step=self.global_step, - dataformats="HWC", - ) - def log_train_val_example( + for logger in self.loggers: + is_wandb = isinstance(logger, WandbLogger) + is_tb = isinstance(logger, TensorBoardLogger) + if not is_wandb and not is_tb: + raise ValueError( + f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." + ) + + wandb_images_log[f"Image/{prefix}/attention_matrix"] = list() + for idx in range(min(3, attention_prob_matrix_mean.size(0))): + item_attn_matrix = attention_prob_matrix_mean[idx][ + dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx] + ] + item_attn_matrix = item_attn_matrix.detach().cpu().numpy() + img_np = plot_alignment_to_numpy(item_attn_matrix.T) + + if is_wandb: + wandb_images_log[f"Image/{prefix}/attention_matrix"].append( + wandb.Image(img_np, caption=f"Example_{idx}") + ) + + if is_tb: + logger.experiment.add_image( + f'{prefix}/attention_matrix/Example_{idx}', + img_np, + global_step=self.global_step, + dataformats="HWC", + ) + + return wandb_images_log + + def log_val_audio_example( self, logits, target_audio_codes, @@ -420,60 +1229,102 @@ def log_train_val_example( context_audio_codes=None, context_audio_codes_lens=None, ): + wandb_audio_log = {} + pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target) pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target) target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target) + context_audio, context_audio_lens = None, None if context_audio_codes is not None and context_audio_codes.shape[2] > 3: # > 3 ensures, it is a valid context audio tensor (and not dummy tensor used in text context) context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens) - for idx in range(min(3, pred_audio.size(0))): - pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() - target_audio_np = target_audio[idx].float().detach().cpu().numpy() - pred_audio_np = pred_audio_np[: pred_audio_lens[idx]] - target_audio_np = target_audio_np[: target_audio_lens[idx]] - self.tb_logger.add_audio( - f'pred_audio_{idx}', - pred_audio_np, - global_step=self.global_step, - sample_rate=self.cfg.sample_rate, - ) - self.tb_logger.add_audio( - f'target_audio_{idx}', - target_audio_np, - global_step=self.global_step, - sample_rate=self.cfg.sample_rate, - ) - if context_audio is not None: - context_audio_np = context_audio[idx].float().detach().cpu().numpy() - context_audio_np = context_audio_np[: context_audio_lens[idx]] - self.tb_logger.add_audio( - f'context_audio_{idx}', - context_audio_np, - global_step=self.global_step, - sample_rate=self.cfg.sample_rate, + + for logger in self.loggers: + is_wandb = isinstance(logger, WandbLogger) + is_tb = isinstance(logger, TensorBoardLogger) + if not is_wandb and not is_tb: + raise ValueError( + f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." 
) + for idx in range(min(3, pred_audio.size(0))): + pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() + target_audio_np = target_audio[idx].float().detach().cpu().numpy() + pred_audio_np = pred_audio_np[: pred_audio_lens[idx]] + target_audio_np = target_audio_np[: target_audio_lens[idx]] + context_audio_np = None + if context_audio is not None: + context_audio_np = context_audio[idx].float().detach().cpu().numpy() + context_audio_np = context_audio_np[: context_audio_lens[idx]] + + if is_wandb: + wandb_audio_log[f"Audio/Example_{idx}"] = list() + if context_audio_np is not None: + wandb_audio_log[f"Audio/Example_{idx}"].append( + wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context") + ) + wandb_audio_log[f"Audio/Example_{idx}"].append( + wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction") + ) + wandb_audio_log[f"Audio/Example_{idx}"].append( + wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target") + ) + + if is_tb: + if context_audio_np is not None: + logger.experiment.add_audio( + f'Example_{idx}/context', + context_audio_np, + global_step=self.global_step, + sample_rate=self.sample_rate, + ) + logger.experiment.add_audio( + f'Example_{idx}/prediction', + pred_audio_np, + global_step=self.global_step, + sample_rate=self.sample_rate, + ) + logger.experiment.add_audio( + f'Example_{idx}/target', + target_audio_np, + global_step=self.global_step, + sample_rate=self.sample_rate, + ) + + return wandb_audio_log + def scale_prior(self, prior, global_step): if prior is None: return None - prior_end_step = self.cfg.prior_end_step - prior_scaledown_start_step = self.cfg.prior_scaledown_start_step - if global_step < prior_scaledown_start_step: + if global_step < self.prior_scaledown_start_step: return prior - elif global_step >= prior_end_step: - return None + elif global_step >= self.prior_end_step: + if random.random() < self.indefinite_prior_prob: + print("Using Prior") + return prior + else: + print("Not using Prior") + return None else: with torch.no_grad(): # Interpolate between all ones and the prior residual = 1.0 - prior new_prior = prior + ( residual - * (global_step - prior_scaledown_start_step) - / (prior_end_step - prior_scaledown_start_step) + * (global_step - self.prior_scaledown_start_step) + / (self.prior_end_step - self.prior_scaledown_start_step) ) return new_prior + def embed_text(self, text, text_mask): + if self.use_bpe_char_tokenizer: + text_embedded = self.cas_encoder(text, subword_mask=text_mask) + else: + text_embedded = self.text_embedding(text) + + return text_embedded + def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_context_size=0): # attention scores: List of (B, C, audio_timesteps, text_timesteps) attention_scores_combined = torch.cat(attention_scores, dim=1) # (B, C, audio_timesteps, text_timesteps) @@ -488,10 +1339,44 @@ def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_co ) return alignment_loss + def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int = 0): + """ + Pads the time dimension of the audio codes to a multiple of the frame stacking factor. + Args: + audio_codes (torch.Tensor): B, C, T + frame_stacking_factor (int): The factor that frames will be stacked by. + pad_token (int): The token ID to pad with. 
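A quick worked example of the padding arithmetic described here, under assumed toy values: with `T = 10` and `frame_stacking_factor = 4`, `T_padded = ceil(10 / 4) * 4 = 12`, so two pad frames are appended.

```python
import numpy as np
import torch

audio_codes = torch.randint(0, 2016, (2, 8, 10))   # (B, C, T=10), toy values
frame_stacking_factor, pad_token = 4, 0

T = audio_codes.size(2)
T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor)   # 12
if T_padded > T:
    pad = torch.full((2, 8, T_padded - T), pad_token, dtype=audio_codes.dtype)
    audio_codes = torch.cat([audio_codes, pad], dim=2)
print(audio_codes.shape)   # torch.Size([2, 8, 12])
```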
+ Returns: + B, C, T_padded + """ + T = audio_codes.size(2) + T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor) + if T_padded > T: + padding = pad_token * torch.ones( + audio_codes.size(0), + audio_codes.size(1), + T_padded - T, + device=audio_codes.device, + dtype=audio_codes.dtype, + ) + audio_codes = torch.cat([audio_codes, padding], dim=2) + return audio_codes + + def embed_context_text(self, context_text_tokens): + if self.legacy_text_conditioning: + context_text_tokens = ( + context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name] + ) + context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E) + else: + context_text_embedded = self.text_embedding(context_text_tokens) # (B, L, E) + + return context_text_embedded + def prepare_context_tensors(self, batch): dec_context_size = 0 additional_decoder_input = None - addtional_decoder_mask = None + additional_decoder_mask = None context_audio_codes = None context_audio_codes_lens = None _attn_prior = None @@ -502,40 +1387,35 @@ def prepare_context_tensors(self, batch): text = None text_lens = None - # self.model_type must be one of - # [single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts, decoder_pretrain_synthesizer] - if self.model_type != 'decoder_pretrain_synthesizer': - text = batch['text'] - text_lens = batch['text_lens'] - text_embedded = self.text_embedding(text) # (B, T, E) - text_mask = get_mask_from_lengths(text_lens) # (B, T) - text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E) - _attn_prior = batch.get('align_prior_matrix', None) - _attn_prior = self.scale_prior(_attn_prior, self.global_step) - - if self.model_type == 'single_encoder_sv_tts': - target_audio_16khz = batch['audio_16khz'] - target_audio_lens_16khz = batch['audio_lens_16khz'] - speaker_embeddings = self.get_speaker_embeddings(target_audio_16khz, target_audio_lens_16khz) - speaker_embeddings_projected = self.speaker_projection_layer(speaker_embeddings) - cond = text_encoder_out + speaker_embeddings_projected.unsqueeze(1) - cond_mask = text_mask - multi_encoder_mapping = None - attn_prior = _attn_prior - elif self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts']: + # self.model_type must be one of [multi_encoder_context_tts, decoder_context_tts, decoder_ce] + text = batch['text'] + text_lens = batch['text_lens'] + text_mask = get_mask_from_lengths(text_lens) # (B, T) + text_embedded = self.embed_text(text, text_mask) # (B, T, E) + text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E) + _attn_prior = batch.get('align_prior_matrix', None) + _attn_prior = self.scale_prior(_attn_prior, self.global_step) + + if self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']: if 'context_audio_codes' in batch: context_audio_codes = batch['context_audio_codes'] context_audio_codes_lens = batch['context_audio_codes_lens'] + if self._codec_converter is not None: + context_audio_codes = self._codec_converter.convert_original_to_new( + audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens + ).long() else: context_audio_codes, context_audio_codes_lens = self.audio_to_codes( batch['context_audio'], batch['context_audio_lens'], audio_type='context' ) - context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T', E) + context_audio_codes = self.pad_audio_codes(context_audio_codes, 
self.frame_stacking_factor, pad_token=0) + context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T/frame_stacking_factor, E) if self.use_text_conditioning_encoder: context_text_tokens = batch['context_text_tokens'] context_text_lens = batch['context_text_tokens_lens'] - context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E) + context_text_embedded = self.embed_context_text(context_text_tokens) # (B, L, E) + # Pad context_audio_embedded or context_text_embedded so that they have same number of timesteps if context_audio_embedded.size(1) < context_text_embedded.size(1): padding = torch.zeros( @@ -564,6 +1444,9 @@ def prepare_context_tensors(self, batch): else: context_input_embedded = context_audio_embedded context_input_lens = context_audio_codes_lens + context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to( + context_input_lens.dtype + ) context_mask = get_mask_from_lengths(context_input_lens) @@ -576,9 +1459,15 @@ def prepare_context_tensors(self, batch): multi_encoder_mapping = self.multi_encoder_mapping attn_prior = [_attn_prior, None] - elif self.model_type == 'decoder_context_tts': + elif self.model_type in ['decoder_context_tts', 'decoder_ce']: dec_context_size = context_mask.size(1) - context_embeddings = context_input_embedded + context_embeddings = None # Address CodeQL + if self.model_type == 'decoder_context_tts': + context_embeddings = context_input_embedded + elif self.model_type == 'decoder_ce': + context_embeddings = self.context_encoder( + context_input_embedded, context_mask, cond=None, cond_mask=None + )['output'] attn_prior = _attn_prior if attn_prior is not None: # B, audio_timesteps, text_timesteps @@ -590,26 +1479,114 @@ def prepare_context_tensors(self, batch): cond_mask = text_mask multi_encoder_mapping = None additional_decoder_input = context_embeddings - addtional_decoder_mask = context_mask - elif self.model_type == 'decoder_pretrain_synthesizer': - pass + additional_decoder_mask = context_mask else: raise ValueError(f"Unsupported model type {self.model_type}") + if attn_prior is not None and self.ctc_prior_layer_ids is not None: + # Convert prior to a list of tensors, one for each layer + # Set None for layers not in ctc_prior_layer_ids + if self.model_type == 'multi_encoder_context_tts': + text_attn_prior = [ + attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None + for layer_idx in range(self.decoder.n_layers) + ] + attn_prior = [text_attn_prior, attn_prior[1]] + else: + attn_prior = [ + attn_prior if layer_idx in self.ctc_prior_layer_ids else None + for layer_idx in range(self.decoder.n_layers) + ] + return { + 'beta_binomial_attn_prior': batch.get('align_prior_matrix', None), + 'text_encoder_out': text_encoder_out, 'cond': cond, 'cond_mask': cond_mask, 'attn_prior': attn_prior, + 'prior_used': _attn_prior is not None, 'multi_encoder_mapping': multi_encoder_mapping, 'additional_decoder_input': additional_decoder_input, - 'addtional_decoder_mask': addtional_decoder_mask, + 'additional_decoder_mask': additional_decoder_mask, 'dec_context_size': dec_context_size, 'text': text, + 'text_embedded': text_embedded, + 'text_mask': text_mask, 'text_lens': text_lens, 'context_audio_codes': context_audio_codes, 'context_audio_codes_lens': context_audio_codes_lens, } + def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_hard): + # aligner_attn_hard B, audio_timesteps, text_timesteps + if self.model_type == 'multi_encoder_context_tts': + text_attn_prior = 
attn_prior[0] + else: + text_attn_prior = attn_prior + + assert text_attn_prior is not None, "Prior is None" + + if isinstance(text_attn_prior, list): + # Layer wise prior + prior_updated = False + for idx, prior in enumerate(text_attn_prior): + if prior is not None: + text_attn_prior[idx][:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard + prior_updated = True + assert prior_updated, "Did not find any prior to update" + else: + # Same prior for all layers + text_attn_prior[:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard + + if self.model_type == 'multi_encoder_context_tts': + attn_prior[0] = text_attn_prior + else: + attn_prior = text_attn_prior + + return attn_prior + + def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens): + # aligner_attn_soft B, 1, audio_timesteps, text_timesteps + if self.binarize_attn_method == 'nemo_binarize': + logging.debug("Binarizing attention using nemo_binarize") + binarize_repeat_audio_factor = self.binarize_repeat_audio_factor + aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave( + binarize_repeat_audio_factor, dim=2 + ) # B, 1, 2*audio_timesteps, text_timesteps + aligner_attn_hard = binarize_attention_parallel( + aligner_attn_soft_repeated, text_lens, audio_lens * binarize_repeat_audio_factor + ).squeeze( + 1 + ) # B, 2*audio_timesteps, text_timesteps + aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps + elif self.binarize_attn_method == 'argmax': + logging.debug("Binarizing attention using argmax") + aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1) + aligner_attn_hard = torch.nn.functional.one_hot( + aligner_attn_hard, num_classes=aligner_attn_soft.size(-1) + ).float() + else: + raise ValueError( + f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'." + ) + + aligner_attn_hard_wider = aligner_attn_hard + self.binarized_prior_epsilon + + for future_timestep in range(self.prior_future_context): + decay_factor = self.prior_future_decay ** (future_timestep + 1) + aligner_attn_hard_wider[:, :, future_timestep + 1 :] += ( + decay_factor * aligner_attn_hard[:, :, : -(future_timestep + 1)] + ) + + for past_timestep in range(self.prior_past_context): + decay_factor = self.prior_past_decay ** (past_timestep + 1) + aligner_attn_hard_wider[:, :, : -past_timestep - 1] += ( + decay_factor * aligner_attn_hard[:, :, past_timestep + 1 :] + ) + + aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0) + return aligner_attn_hard_wider + def prepare_dummy_cond_for_cfg(self, cond, cond_mask, additional_decoder_input, additional_dec_mask): dummy_additional_decoder_input = None dummy_additional_dec_mask = None @@ -648,24 +1625,54 @@ def process_batch(self, batch, mode="train"): else: audio_codes = batch['audio_codes'] audio_codes_lens = batch['audio_codes_lens'] - - audio_codes_input = audio_codes[:, :, :-1] # B, C, T' - audio_codes_target = audio_codes[:, :, 1:] - audio_codes_lens_input = audio_codes_lens_target = audio_codes_lens - 1 + if self._codec_converter: + audio_codes = self._codec_converter.convert_original_to_new( + audio_tokens=audio_codes, audio_lens=audio_codes_lens + ).long() + if self.frame_stacking_factor > 1: + # repeat the BOS token to frame_stacking_factor times. This is necessary since at inference + # we need to start autoregressive generation from a full stack indicating BOS. 
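To make the BOS handling described in this comment concrete: with `frame_stacking_factor = 2`, a sequence that begins with a single BOS frame gets one extra BOS frame prepended so that the first stack is entirely BOS, and the length is bumped accordingly. A toy standalone sketch using placeholder token ids:

```python
import torch

audio_bos_id, frame_stacking_factor = 2016, 2
audio_codes = torch.tensor([[[2016, 5, 9, 3, 2017]]])     # (B=1, C=1, T=5): BOS, codes, EOS
audio_codes_lens = torch.tensor([5])

extra_bos = torch.full(
    (audio_codes.size(0), audio_codes.size(1), frame_stacking_factor - 1),
    audio_bos_id, dtype=audio_codes.dtype,
)
audio_codes = torch.cat([extra_bos, audio_codes], dim=2)  # first stack is all BOS
audio_codes_lens = audio_codes_lens + frame_stacking_factor - 1
print(audio_codes[0, 0].tolist())   # [2016, 2016, 5, 9, 3, 2017]
```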
+ # TODO: @rfejgin: this assert might be slow due to GPU/CPU sync + assert (audio_codes[:, :, 0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token" + audio_codes = torch.cat( + [ + torch.full( + (audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1), + self.audio_bos_id, + device=audio_codes.device, + dtype=audio_codes.dtype, + ), + audio_codes, + ], + dim=2, + ) + audio_codes_lens += self.frame_stacking_factor - 1 # account for BOS repeat + audio_codes = self.pad_audio_codes(audio_codes, self.frame_stacking_factor, pad_token=0) + # Note: if a tensor lacks the `_unstacked` suffix, it can be assumed to to be in the frame-stacked domain + + # drop last (stacked) frame since it is not part of *input* + audio_codes_input_unstacked = audio_codes[:, :, : -self.frame_stacking_factor] # B, C, T' + # drop first (stacked) frame which contains BOS token(s) which are not part of *target* + audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor :] + audio_codes_lens_input_unstacked = audio_codes_lens - 1 # don't count EOS for input + audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor # don't count BOS for target + audio_codes_lens_input = torch.floor(audio_codes_lens_input_unstacked / self.frame_stacking_factor).long() + audio_codes_embedded_all = self.embed_audio_tokens( + audio_codes + ) # (B, T, E) # Computing this to be use in the alignment encoder + audio_codes_embedded = audio_codes_embedded_all[ + :, :-1, : + ] # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`) audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input) - use_cfg = ( - (self.cfg.get('cfg_unconditional_prob', 0.0) > 0.0) - and (mode == "train") - and (context_tensors['cond'] is not None) - ) - if use_cfg and torch.rand(1).item() < self.cfg.cfg_unconditional_prob: + use_cfg = (self.cfg_unconditional_prob > 0.0) and (mode == "train") and (context_tensors['cond'] is not None) + if use_cfg and torch.rand(1).item() < self.cfg_unconditional_prob: cond, cond_mask, additional_decoder_input, additional_decoder_mask, attn_prior = ( self.prepare_dummy_cond_for_cfg( context_tensors['cond'], context_tensors['cond_mask'], context_tensors['additional_decoder_input'], - context_tensors['addtional_decoder_mask'], + context_tensors['additional_decoder_mask'], ) ) disable_alignment_loss = True @@ -673,31 +1680,29 @@ def process_batch(self, batch, mode="train"): cond = context_tensors['cond'] cond_mask = context_tensors['cond_mask'] additional_decoder_input = context_tensors['additional_decoder_input'] - additional_decoder_mask = context_tensors['addtional_decoder_mask'] + additional_decoder_mask = context_tensors['additional_decoder_mask'] attn_prior = context_tensors['attn_prior'] - if ( - mode == "train" - and self.cfg.get('decoder_input_dropout_prob', 0.0) > 0.0 - and torch.rand(1).item() < 0.5 - ): + if mode == "train" and self.decoder_input_dropout_prob > 0.0 and torch.rand(1).item() < 0.5: # For some batches (half of them), replace decoder_input_dropout_prob of the timesteps with random tokens - max_codebook_val = self.cfg.get('dec_random_input_max', self.cfg.num_audio_tokens_per_codebook) - # @pneekhara: Keeping dec_random_input_max configurable since num_audio_tokens_per_codebook usually has padding tokens + max_codebook_val = self.dec_random_input_max + # @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens # 
which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on - # audio_codes_input so should not matter if we dont supply dec_random_input_max. + # audio_codes_input so should not matter if we don't supply dec_random_input_max. random_audio_tokens = torch.randint( - 0, max_codebook_val, audio_codes_input.size(), device=audio_codes_input.device + 0, max_codebook_val, audio_codes_input_unstacked.size(), device=audio_codes_input_unstacked.device ) random_audio_tokens = random_audio_tokens * audio_codes_mask.unsqueeze(1) dec_dropout_mask = ( - torch.rand((1, 1, audio_codes_input.size(2)), device=audio_codes_input.device) - > self.cfg.decoder_input_dropout_prob + torch.rand((1, 1, audio_codes_input_unstacked.size(2)), device=audio_codes_input_unstacked.device) + > self.decoder_input_dropout_prob ) # timestep_mask is True for timesteps to be kept - audio_codes_input = audio_codes_input * dec_dropout_mask + random_audio_tokens * (~dec_dropout_mask) + audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * ( + ~dec_dropout_mask + ) + audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked) # (B, T', E) - audio_codes_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T', E) if context_tensors['additional_decoder_input'] is not None: dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1) dec_input_mask = torch.cat([additional_decoder_mask, audio_codes_mask], dim=1) @@ -705,7 +1710,45 @@ def process_batch(self, batch, mode="train"): dec_input_embedded = audio_codes_embedded dec_input_mask = audio_codes_mask - logits, attn_info = self.forward( + aligner_encoder_loss = None + aligner_attn_soft = None + aligner_attn_hard = None + if self.use_alignment_encoder and not disable_alignment_loss: + aligner_prior = None + if self.use_prior_for_aligner: + aligner_prior = context_tensors['beta_binomial_attn_prior'] + # Passing target audio embeddings to the alignment encoder + if self.global_step < self.aligner_encoder_train_steps: + aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( + queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' + keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T + mask=~context_tensors['text_mask'].unsqueeze(-1), + attn_prior=aligner_prior, + ) + + aligner_encoder_loss = self.alignment_encoder_loss( + attn_logprob=aligner_attn_logprobs, + in_lens=context_tensors['text_lens'], + out_lens=audio_codes_lens_input, + ) + else: + with torch.no_grad(): + # Just get the attention matrix without computing the loss or gradients + aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( + queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' + keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T + mask=~context_tensors['text_mask'].unsqueeze(-1), + attn_prior=aligner_prior, + ) + + with torch.no_grad(): + aligner_attn_hard = self.get_binarized_prior_matrix( + aligner_attn_soft, audio_codes_lens_input, context_tensors['text_lens'] + ) + if (self.global_step > self.binarize_prior_after_step) and context_tensors['prior_used']: + attn_prior = self.replace_beta_binomial_prior_with_binarized(attn_prior, aligner_attn_hard) + + logits, attn_info, dec_out = self.forward( dec_input_embedded=dec_input_embedded, dec_input_mask=dec_input_mask, cond=cond, @@ -714,61 +1757,154 @@ def process_batch(self, batch, mode="train"): 
multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) # logits: (B, T', num_codebooks * num_tokens_per_codebook) + # dec_out: (B, T', E) dec_context_size = context_tensors['dec_context_size'] logits = logits[:, dec_context_size:, :] # Remove the context audio embeddings from the logits - codebook_loss, loss_mask = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target) + # Codebook loss (parallel) + codebook_loss, loss_mask = self.compute_loss( + logits, + audio_codes_target_unstacked, + audio_codes_lens_target_unstacked, + frame_stacking_factor=self.frame_stacking_factor, + ) + # Alignment loss alignment_loss = None - if self.cfg.alignment_loss_scale > 0.0 and not disable_alignment_loss: + if self.alignment_loss_scale > 0.0 and not disable_alignment_loss: text_lens = context_tensors['text_lens'] cross_attention_scores = [ attn['cross_attn_probabilities'][1] for layer_idx, attn in enumerate(attn_info) - if layer_idx in self.transcript_decoder_layers + if layer_idx in self.ctc_prior_layer_ids ] alignment_loss = self.compute_alignment_loss( - cross_attention_scores, text_lens, audio_codes_lens_target, dec_context_size + cross_attention_scores, text_lens, audio_codes_lens_input, dec_context_size ) - loss = codebook_loss + alignment_loss + loss = self.codebook_loss_scale * codebook_loss + alignment_loss else: - loss = codebook_loss + loss = self.codebook_loss_scale * codebook_loss + + # Local Transformer loss + local_transformer_loss = None + local_transformer_logits = None + if self.local_transformer_type != LocalTransformerType.NO_LT: + if self.local_transformer_type == LocalTransformerType.MASKGIT: + # Maskgit + # randomly replace some positions with MASK_TOKEN + audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked) + # TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part of + # of a pair where the first position is valid. Is this an issue? 
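+ # The Local Transformer is conditioned on the base decoder output for this frame stack (`dec_out`)
+ # together with the partially masked target codes; `mask_tokens_mask` is passed to `compute_loss`
+ # below so that, as in MaskGit-style training, only the positions replaced with the MASK token
+ # contribute to the Local Transformer loss.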
+ local_transformer_logits = self.compute_local_transformer_logits( + dec_out[:, dec_context_size:, :], audio_codes_masked, targets_offset_by_one=True + ) + local_transformer_loss, _ = self.compute_loss( + local_transformer_logits, + audio_codes_target_unstacked, + audio_codes_lens_target_unstacked, + mask_tokens_mask, + frame_stacking_factor=self.frame_stacking_factor, + ) + else: + # Autoregressive + assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type" + local_transformer_logits = self.compute_local_transformer_logits( + dec_out[:, dec_context_size:, :], audio_codes_target_unstacked, targets_offset_by_one=False + ) + local_transformer_loss, _ = self.compute_loss( + local_transformer_logits, + audio_codes_target_unstacked, + audio_codes_lens_target_unstacked, + None, + frame_stacking_factor=self.frame_stacking_factor, + ) + loss = loss + self.local_transformer_loss_scale * local_transformer_loss + + if aligner_encoder_loss is not None: + loss = loss + aligner_encoder_loss return { 'logits': logits, 'attn_info': attn_info, 'loss': loss, 'codebook_loss': codebook_loss, + 'local_transformer_loss': local_transformer_loss, + 'local_transformer_logits': local_transformer_logits, 'loss_mask': loss_mask, 'alignment_loss': alignment_loss, - 'audio_codes_target': audio_codes_target, - 'audio_codes_lens_target': audio_codes_lens_target, + 'aligner_encoder_loss': aligner_encoder_loss, + 'audio_codes_target': audio_codes_target_unstacked, + 'audio_codes_lens_target': audio_codes_lens_target_unstacked, 'text': context_tensors['text'], 'text_lens': context_tensors['text_lens'], 'context_audio_codes': context_tensors['context_audio_codes'], 'context_audio_codes_lens': context_tensors['context_audio_codes_lens'], 'dec_context_size': dec_context_size, + 'aligner_attn_soft': aligner_attn_soft, + 'aligner_attn_hard': aligner_attn_hard, } def training_step(self, batch, batch_idx): batch_output = self.process_batch(batch) loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] - self.log('train_codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) - if self.cfg.get('cfg_unconditional_prob', 0.0) == 0.0: + self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) + if self.cfg_unconditional_prob == 0.0: # Only log alignment loss when not using cfg to avoid sync issues when # alignment loss is None on some ranks alignment_loss = batch_output['alignment_loss'] if alignment_loss is not None: - self.log('train_alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) - self.log('train_loss', loss, prog_bar=True, sync_dist=True) + self.log('train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) + self.log('train/loss', loss, prog_bar=True, sync_dist=True) + local_transformer_loss = batch_output['local_transformer_loss'] + if local_transformer_loss is not None: + self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) + + # Log batch info + batch_size, text_token_max_len = batch["text"].shape + text_token_total_num = batch["text_lens"].sum() + batch_info_dict = { + "train/batch_size": batch_size, + "train/text_token_max_len": text_token_max_len, + "train/text_token_total_num_in_batch": text_token_total_num.item(), + "train/text_token_pad_ratio_percent_in_batch": 100 + * (1 - text_token_total_num / (batch_size * text_token_max_len)), + } + + if "audio_codes" in batch: + audio_codes_max_len = batch["audio_codes"].shape[-1] + audio_codes_total_num = 
batch["audio_codes_lens"].sum() + batch_info_dict.update( + { + "train/audio_codes_max_len": audio_codes_max_len, + "train/audio_codes_total_num_in_batch": audio_codes_total_num.item(), + "train/audio_codes_pad_ratio_percent_in_batch": 100 + * (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)), + } + ) + else: + audio_samples_max_len = batch["audio"].shape[-1] + audio_samples_total_num = batch["audio_lens"].sum() + batch_info_dict.update( + { + "train/audio_samples_max_len": audio_samples_max_len, + "train/audio_samples_total_num_in_batch": audio_samples_total_num.item(), + "train/audio_samples_pad_ratio_percent_in_batch": 100 + * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)), + } + ) + + self.log_dict(batch_info_dict, on_step=True) return loss def validation_step(self, batch, batch_idx): batch_output = self.process_batch(batch, mode="val") + # self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction + # head. If we use local_transformer, then the local_transformer returns "local_transformer_logits" loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] alignment_loss = batch_output['alignment_loss'] + aligner_encoder_loss = batch_output['aligner_encoder_loss'] logits = batch_output['logits'] audio_codes_target = batch_output['audio_codes_target'] audio_codes_lens_target = batch_output['audio_codes_lens_target'] @@ -779,48 +1915,342 @@ def validation_step(self, batch, batch_idx): dec_context_size = batch_output['dec_context_size'] if alignment_loss is None: alignment_loss = torch.tensor(0.0, device=loss.device) + if aligner_encoder_loss is None: + aligner_encoder_loss = torch.tensor(0.0, device=loss.device) if batch_idx == 0 and self.global_rank == 0: - self.log_train_val_example( - logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens + # Prepare dictionary for aggregated wandb logging + wandb_log_dict = {} + + # Get audio data for logging + wandb_log_dict.update( + self.log_val_audio_example( + logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens + ) ) - if ( - self.model_type != 'decoder_pretrain_synthesizer' - and len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1 - ): + + # Get attention image data for logging + if len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1: # cross_attn_probabilities only returned when not using flash attention cross_attention_probs = [ attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) - if layer_idx in self.transcript_decoder_layers + if layer_idx in self.ctc_prior_layer_ids ] - self.log_attention_probs( - cross_attention_probs, - audio_codes_lens_target, - text_lens, - prefix="val_", - dec_context_size=dec_context_size, + wandb_log_dict.update( + self.log_attention_probs( + cross_attention_probs, + audio_codes_lens_target, + text_lens, + prefix="val", + dec_context_size=dec_context_size, + ) ) + for layer_idx in self.transcript_decoder_layers: + cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]] + wandb_log_dict.update( + self.log_attention_probs( + cross_attention_probs, + audio_codes_lens_target, + text_lens, + prefix=f"val/layer_{layer_idx}", + dec_context_size=dec_context_size, + ) + ) + + if batch_output['aligner_attn_soft'] is not None: + wandb_log_dict.update( + self.log_attention_probs( + [batch_output['aligner_attn_soft']], + 
audio_codes_lens_target, + text_lens, + prefix="val/aligner_encoder_attn", + ) + ) + + if batch_output['aligner_attn_hard'] is not None: + wandb_log_dict.update( + self.log_attention_probs( + [batch_output['aligner_attn_hard'].unsqueeze(1)], + audio_codes_lens_target, + text_lens, + prefix="val/aligner_encoder_attn_hard", + ) + ) + + # Perform single wandb log call if wandb is active and there is data + for logger in self.loggers: + if isinstance(logger, WandbLogger) and wandb_log_dict: + logger.experiment.log(wandb_log_dict) + + local_transformer_loss = batch_output['local_transformer_loss'] val_output = { 'val_loss': loss, 'val_codebook_loss': codebook_loss, 'val_alignment_loss': alignment_loss, + 'val_local_transformer_loss': local_transformer_loss, + 'val_aligner_encoder_loss': aligner_encoder_loss, } self.validation_step_outputs.append(val_output) return val_output - def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, use_cfg=False, cfg_scale=1.0): + def get_cross_attention_scores(self, attn_probs, filter_layers=None): + """ + Returns the cross attention probabilities for the last audio timestep + """ + mean_cross_attn_scores = [] + all_heads_cross_attn_scores = [] + for lidx, layerwise_attn_prob in enumerate(attn_probs): + if (filter_layers is not None and lidx not in filter_layers) or ( + lidx not in self.transcript_decoder_layers + ): + continue + cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][ + 0 + ] # B, H, audio_timesteps, text_timesteps + mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1)) # B, audio_timesteps, text_timesteps + for head_idx in range(cross_attn_prob.size(1)): + all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :]) # B, text_timesteps + + mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1) # B, L, audio_timesteps, text_timesteps + mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1) # B, audio_timesteps, text_timesteps + last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps + return last_audio_timestep_scores, all_heads_cross_attn_scores + + def get_most_attended_text_timestep( + self, + alignment_attention_scores, + last_attended_timesteps, + text_lens, + lookahead_window_size, + attended_timestep_counter, + batch_size, + ): + """ + Returns the most attended timestep for each batch item + """ + text_time_step_attended = [] + for bidx in range(batch_size): + last_attended_timestep = last_attended_timesteps[-1][bidx] + if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8: + # This is probably an attention sink! 
Move to the next timestep + last_attended_timestep += 1 + window_size = lookahead_window_size + window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps + item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end] + if item_attention_scores.size(0) == 0: + # This means the sentence has ended + attended_timestep = text_lens[bidx].item() - 1 + else: + attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep + text_time_step_attended.append(attended_timestep) + attended_timestep_counter[bidx][attended_timestep] = ( + attended_timestep_counter[bidx].get(attended_timestep, 0) + 1 + ) + return text_time_step_attended, attended_timestep_counter + + def construct_inference_prior( + self, + prior_epsilon, + cross_attention_scores, + text_lens, + text_time_step_attended, + attended_timestep_counter, + unfinished_texts, + finished_texts_counter, + end_indices, + lookahead_window_size, + batch_size, + ): + # Attn prior for the next timestep + _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon + _attn_prior = _attn_prior.to(cross_attention_scores.device) + for bidx in range(cross_attention_scores.shape[0]): + if bidx < batch_size: + _text_len = text_lens[bidx] + if text_lens[bidx] <= 5: + # Very short sentences, no prior + _attn_prior[bidx, 0, :] = 1.0 + else: + _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx] - 1)] = ( + 1.0 # Slight exposure to history for better pronunciation. Not very important. + ) + _attn_prior[bidx, 0, text_time_step_attended[bidx]] = ( + 1.0 # Slightly bias to continue moving forward. Not very important. + ) + for ind in range(1, lookahead_window_size + 1): + _attn_prior[bidx, 0, min(text_time_step_attended[bidx] + ind, _text_len - 1)] = 1.0 + + # Penalize timesteps that have been attended to more than 10 times + for _timestep in attended_timestep_counter[bidx]: + if attended_timestep_counter[bidx][_timestep] >= 10: + # This means the timestep has been attended to more than 10 times (to avoid getting stuck) + _attn_prior[bidx, 0, : _timestep + 1] = prior_epsilon + + unfinished_texts[bidx] = False + if text_time_step_attended[bidx] < text_lens[bidx] - 3: + # This means the sentence has not ended + if bidx not in end_indices: + unfinished_texts[bidx] = True + + if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices: + if bidx not in finished_texts_counter: + finished_texts_counter[bidx] = 0 + + for bidx in finished_texts_counter: + finished_texts_counter[bidx] += 1 + if finished_texts_counter[bidx] > 5: + # This means we have been within the text EOS window for at least 5 timesteps + # We should allow EOS to be predicted now. 
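+ # Once this flag is cleared, the item drops out of `unfinished_items` in `infer_batch` and is no
+ # longer treated as unfinished by the sampling functions, so an EOS token may be sampled.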
+ unfinished_texts[bidx] = False + + return _attn_prior, unfinished_texts, finished_texts_counter + + def get_inference_attention_plots( + self, + cross_attention_scores_all_timesteps, + all_heads_cross_attn_scores_all_timesteps, + text_lens, + predicted_codes_lens, + batch_size, + compute_all_heads_attn_maps, + last_attended_timestep, + ): + last_attended_timestep = np.array(last_attended_timestep).T + cross_attention_scores_all_timesteps = torch.stack( + cross_attention_scores_all_timesteps, dim=2 + ) # B, text_timesteps, T' + headwise_cross_attention_scores_all_timesteps = [] + for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): + head_cross_attention_all_timesteps = torch.stack( + [x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2 + ) # B, text_timesteps, T' + headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps) + + cross_attention_maps = [] + headwise_cross_attention_maps = [] + for bidx in range(batch_size): + item_cross_attention_scores = cross_attention_scores_all_timesteps[ + bidx, : text_lens[bidx], : predicted_codes_lens[bidx] + ] + cross_attn_np = plot_alignment_to_numpy( + item_cross_attention_scores.cpu().numpy(), + attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]], + ) + cross_attention_maps.append(cross_attn_np) + item_all_head_cross_attn_maps = [] + if compute_all_heads_attn_maps: + for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): + item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][ + bidx, : text_lens[bidx], : predicted_codes_lens[bidx] + ] + headwise_cross_attn_np = plot_alignment_to_numpy( + item_headwise_cross_attention_scores.cpu().numpy(), + attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]], + ) + item_all_head_cross_attn_maps.append(headwise_cross_attn_np) + headwise_cross_attention_maps.append(item_all_head_cross_attn_maps) + + return cross_attention_maps, headwise_cross_attention_maps + + def find_eos_frame_index(self, codes, eos_detection_method) -> Union[int, float]: + """ + Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack + that contains an EOS token across any codebook, or `None` if no EOS is found. + Args: + codes: (num_codebooks, frame_stacking_factor) + Returns: + index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found + """ + eos_mask = codes == self.audio_eos_id # (codebooks, frame_stacking_factor) + detection_type = EOSDetectionMethod.detection_type(eos_detection_method) + if detection_type == "any": + eos_per_frame = eos_mask.any( + dim=0 + ) # (frame_stacking_factor,) - True if any codebook has EOS in this frame + elif detection_type == "all": + eos_per_frame = eos_mask.all( + dim=0 + ) # (frame_stacking_factor,) - True if all codebooks have EOS in this frame + elif detection_type == "zero_cb": + eos_per_frame = eos_mask[:1, :].any( + dim=0 + ) # (frame_stacking_factor,) - True if zeroth codebook has EOS in this frame + else: + raise ValueError(f"Invalid EOS detection method: {eos_detection_method}") + # find first frame with EOS + if eos_per_frame.any(): + # return index of the first frame with EOS + return eos_per_frame.nonzero()[0].item() + return float('inf') + + def detect_eos(self, audio_codes_multinomial, audio_codes_argmax, eos_detection_method) -> Union[int, float]: + """ + Detects EOS in the predicted codes. 
Returns the index of the first frame within the frame stack + that triggers EOS detection, or `float('inf')` if no EOS is found. + Args: + audio_codes_multinomial: (num_codebooks, frame_stacking_factor) - Multinomial samples + audio_codes_argmax: (num_codebooks, frame_stacking_factor) - Argmax samples + eos_detection_method: EOS detection method + Returns: + index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found + """ + sampling_type = EOSDetectionMethod.sampling_type(eos_detection_method) + if sampling_type == "argmax": + return self.find_eos_frame_index(audio_codes_argmax, eos_detection_method) + elif sampling_type == "argmax_or_multinomial": + argmax_eos_frame = self.find_eos_frame_index(audio_codes_argmax, eos_detection_method) + multinomial_eos_frame = self.find_eos_frame_index(audio_codes_multinomial, eos_detection_method) + return min(argmax_eos_frame, multinomial_eos_frame) + else: + raise ValueError(f"Invalid EOS detection method: {eos_detection_method}") + + def infer_batch( + self, + batch, + max_decoder_steps=500, + temperature=0.7, + topk=80, + use_cfg=False, + cfg_scale=1.0, + return_cross_attn_probs=False, + apply_attention_prior=False, + prior_epsilon=1e-5, + lookahead_window_size=10, + estimate_alignment_from_layers=None, + apply_prior_to_layers=None, + start_prior_after_n_audio_steps=10, + compute_all_heads_attn_maps=False, + use_local_transformer_for_inference=False, + use_LT_kv_cache=True, + maskgit_n_steps=3, + maskgit_noise_scale=0.0, + maskgit_fixed_schedule=None, + maskgit_dynamic_cfg_scale=False, + maskgit_sampling_type=None, + ignore_finished_sentence_tracking=False, + eos_detection_method="argmax_or_multinomial_any", + # Setting this greater than 0 prevents rare cases of first-frame termination. Any number between 1 and 4 should work, but 4 + # lines up with the codec's minimum frame requirement. 
+ min_generated_frames=4, + ): + eos_detection_method = EOSDetectionMethod(eos_detection_method) with torch.no_grad(): + start_time = time.time() self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference) context_tensors = self.prepare_context_tensors(batch) text = context_tensors['text'] audio_codes_bos = torch.full( - (text.size(0), self.cfg.num_audio_codebooks, 1), self.audio_bos_id, device=text.device + (text.size(0), self.num_audio_codebooks, self.frame_stacking_factor), + self.audio_bos_id, + device=text.device, ).long() - audio_codes_lens = torch.full((text.size(0),), 1, device=text.device).long() + audio_codes_lens = torch.full( + (text.size(0),), 1, device=text.device + ).long() # intetionally 1 rather than self.frame_stacking_factor since this is in stacked form audio_codes_input = audio_codes_bos audio_codes_mask = get_mask_from_lengths(audio_codes_lens) @@ -833,11 +2263,23 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us context_tensors['cond'], context_tensors['cond_mask'], context_tensors['additional_decoder_input'], - context_tensors['addtional_decoder_mask'], + context_tensors['additional_decoder_mask'], ) ) - for idx in range(max_decoder_steps): + cross_attention_scores_all_timesteps = [] + all_heads_cross_attn_scores_all_timesteps = [] + _attn_prior = None + unfinished_texts = {} + finished_texts_counter = {} + attended_timestep_counter = [{} for _ in range(text.size(0))] + last_attended_timesteps = [ + [1 for _ in range(text.size(0))] + ] # Maintain a list of attended timesteps as we predict audio for each batch item + time_to_first_prediction = 0.0 + for idx in range(max_decoder_steps // self.frame_stacking_factor): + if idx == 1: + time_to_first_prediction = time.time() - start_time if idx % 20 == 0: print(f"Decoding timestep {idx}") audio_codes_embedded = self.embed_audio_tokens(audio_codes_input) @@ -845,14 +2287,25 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us _audio_codes_embedded = torch.cat( [context_tensors['additional_decoder_input'], audio_codes_embedded], dim=1 ) - _audio_codes_mask = torch.cat([context_tensors['addtional_decoder_mask'], audio_codes_mask], dim=1) + _audio_codes_mask = torch.cat( + [context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1 + ) else: _audio_codes_embedded = audio_codes_embedded _audio_codes_mask = audio_codes_mask + if apply_prior_to_layers is not None: + attn_prior = [None for _ in range(self.decoder.n_layers)] + for layer_idx in apply_prior_to_layers: + attn_prior[layer_idx] = _attn_prior + else: + attn_prior = _attn_prior + + if self.model_type == 'multi_encoder_context_tts': + attn_prior = [attn_prior, None] + if use_cfg: batch_size = audio_codes_embedded.size(0) - # Combine conditional and unconditional inputs into one batch if isinstance(context_tensors['cond'], list): cfg_cond = [ torch.cat([cond_item, dummy_cond_item], dim=0) @@ -877,12 +2330,16 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us dummy_addition_dec_mask ) - combined_logits, _ = self.forward( + # print(f"step {idx}") + # print(f"use_cfg {use_cfg}") + # print(f"shape {cfg_audio_codes_embedded.shape}") + # print(f"use kv cahce? 
{self.use_kv_cache_for_inference}") + combined_logits, attn_probs, dec_out = self.forward( dec_input_embedded=cfg_audio_codes_embedded, dec_input_mask=cfg_audio_codes_mask, cond=cfg_cond, cond_mask=cfg_cond_mask, - attn_prior=None, + attn_prior=attn_prior, multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) @@ -890,48 +2347,185 @@ def infer_batch(self, batch, max_decoder_steps=500, temperature=0.7, topk=80, us uncond_logits = combined_logits[batch_size:] all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits else: - all_code_logits, _ = self.forward( + batch_size = audio_codes_embedded.size(0) + all_code_logits, attn_probs, dec_out = self.forward( dec_input_embedded=_audio_codes_embedded, dec_input_mask=_audio_codes_mask, cond=context_tensors['cond'], cond_mask=context_tensors['cond_mask'], - attn_prior=None, + attn_prior=attn_prior, multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) + + if return_cross_attn_probs or apply_attention_prior: + cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores( + attn_probs + ) # B, text_timesteps + alignment_attention_scores = cross_attention_scores + if estimate_alignment_from_layers is not None: + alignment_attention_scores, _ = self.get_cross_attention_scores( + attn_probs, filter_layers=estimate_alignment_from_layers + ) # B, text_timesteps + + cross_attention_scores_all_timesteps.append(cross_attention_scores) + all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores) + + if apply_attention_prior and idx >= start_prior_after_n_audio_steps: + text_time_step_attended, attended_timestep_counter = self.get_most_attended_text_timestep( + alignment_attention_scores=alignment_attention_scores, + last_attended_timesteps=last_attended_timesteps, + text_lens=context_tensors['text_lens'], + lookahead_window_size=lookahead_window_size, + attended_timestep_counter=attended_timestep_counter, + batch_size=batch_size, + ) + last_attended_timesteps.append(text_time_step_attended) + _attn_prior, unfinished_texts, finished_texts_counter = self.construct_inference_prior( + prior_epsilon=prior_epsilon, + cross_attention_scores=cross_attention_scores, + text_lens=context_tensors['text_lens'], + text_time_step_attended=text_time_step_attended, + attended_timestep_counter=attended_timestep_counter, + unfinished_texts=unfinished_texts, + finished_texts_counter=finished_texts_counter, + end_indices=end_indices, + lookahead_window_size=lookahead_window_size, + batch_size=batch_size, + ) + + if ignore_finished_sentence_tracking: + finished_items = {} + unfinished_items = {} + else: + finished_items = { + k: v for k, v in finished_texts_counter.items() if v >= 20 + } # Items that have been close to the end for atleast 20 timesteps + unfinished_items = {k: v for k, v in unfinished_texts.items() if v} + + # Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor) + # This guards against rare cases of termination right at the start of generation. 
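+ # `idx` counts base-decoder steps and each step emits `frame_stacking_factor` raw frames, so
+ # `idx * self.frame_stacking_factor` is the number of audio frames generated so far.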
+ forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames + all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) - audio_codes_next = self.sample_codes_from_logits( - all_code_logits_t, temperature=temperature, topk=topk - ) # (B, num_codebooks) + if use_local_transformer_for_inference: + if self.local_transformer_type == LocalTransformerType.AR: + # Autoregressive sampling with local transformer + audio_codes_next = self.local_transformer_sample_autoregressive( + dec_output=dec_out[:, -1, :], + temperature=temperature, + topk=topk, + unfinished_items=unfinished_items, + finished_items=finished_items, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + use_kv_cache=use_LT_kv_cache, + forbid_audio_eos=forbid_audio_eos, + ) + elif self.local_transformer_type == LocalTransformerType.MASKGIT: + audio_codes_next = self.local_transformer_sample_maskgit( + dec_output=dec_out[:, -1, :], + temperature=temperature, + topk=topk, + unfinished_items=unfinished_items, + finished_items=finished_items, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + n_steps=maskgit_n_steps, + noise_scale=maskgit_noise_scale, + fixed_schedule=maskgit_fixed_schedule, + dynamic_cfg_scale=maskgit_dynamic_cfg_scale, + sampling_type=maskgit_sampling_type, + forbid_audio_eos=forbid_audio_eos, + ) + else: + raise ValueError( + f"Local transformer inference requested by but local transformer type is {self.local_transformer_type}" + ) + else: + # Parallel sampling from all codebooks + audio_codes_next = self.sample_codes_from_logits( + all_code_logits_t, + temperature=temperature, + topk=topk, + unfinished_items=unfinished_items, + finished_items=finished_items, + forbid_audio_eos=forbid_audio_eos, + ) # (B, num_codebooks, frame_stacking_factor) all_codes_next_argmax = self.sample_codes_from_logits( - all_code_logits_t, temperature=0.01 - ) # (B, num_codebooks) + all_code_logits_t, + temperature=0.01, + topk=1, + unfinished_items=unfinished_items, + finished_items=finished_items, + forbid_audio_eos=forbid_audio_eos, + ) # (B, num_codebooks, frame_stacking_factor) for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: - pred_token = all_codes_next_argmax[item_idx][0].item() - pred_token_multinomial = audio_codes_next[item_idx][0].item() - if (pred_token == self.audio_eos_id) or (pred_token_multinomial == self.audio_eos_id): - print("End detected for item {} at timestep {}".format(item_idx, idx)) - end_indices[item_idx] = idx + end_frame_index = self.detect_eos( + audio_codes_next[item_idx], all_codes_next_argmax[item_idx], eos_detection_method + ) + if end_frame_index != float('inf'): + global_index = idx * self.frame_stacking_factor + end_frame_index + end_indices[item_idx] = global_index + print(f"End detected for item {item_idx} at decoder timestep: {idx}") all_predictions.append(audio_codes_next) - audio_codes_input = torch.cat( - [audio_codes_input, audio_codes_next.unsqueeze(-1)], dim=-1 - ) # (B, C, T') - audio_codes_lens = audio_codes_lens + 1 + audio_codes_input = torch.cat([audio_codes_input, audio_codes_next], dim=-1) # (B, C, T') + audio_codes_lens = audio_codes_lens + 1 # already in stacked form audio_codes_mask = get_mask_from_lengths(audio_codes_lens) - if len(end_indices) == text.size(0): + if len(end_indices) == text.size(0) and len(all_predictions) >= 4: + # Codec must be of atleast 4 timesteps to be decoded properly print("All ends reached") break - - predicted_codes = torch.stack(all_predictions, dim=-1) # (B, num_codebooks, T') - 
predicted_lens = [end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))] + tts_generation_time = time.time() - start_time + tts_generation_time_per_frame = tts_generation_time / (len(all_predictions) * self.frame_stacking_factor) + + # Concatenate the list of predictions along the time dimension. Note that when frame stacking is on, + # this also undoes the stacking. + predicted_codes = torch.cat(all_predictions, dim=-1) # (B, num_codebooks, T') + predicted_lens = [ + end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0)) + ] # Ensure that the codec is atleast of length 4 predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long() predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens) - + end_time = time.time() + total_audio_duration_generated = ( + predicted_audio_lens.max().item() * predicted_audio_lens.shape[0] + ) / self.sample_rate + rtf = total_audio_duration_generated / (end_time - start_time) + rtf_metrics = { + 'rtf': rtf, + 'time_to_first_prediction': time_to_first_prediction, + 'tts_generation_time': tts_generation_time, + 'max_frames_generated': len(all_predictions), + 'tts_generation_time_per_frame': tts_generation_time_per_frame, + 'batch_size': text.size(0), + } torch.cuda.empty_cache() - return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens + if return_cross_attn_probs: + cross_attention_maps, headwise_cross_attention_maps = self.get_inference_attention_plots( + cross_attention_scores_all_timesteps, + all_heads_cross_attn_scores_all_timesteps, + context_tensors['text_lens'], + predicted_codes_lens, + text.size(0), + compute_all_heads_attn_maps, + last_attended_timesteps, + ) + return ( + predicted_audio, + predicted_audio_lens, + predicted_codes, + predicted_codes_lens, + rtf_metrics, + cross_attention_maps, + headwise_cross_attention_maps, + ) + else: + # For backward compatibility + return predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, rtf_metrics def test_step(self, batch, batch_idx): with torch.no_grad(): @@ -940,7 +2534,7 @@ def test_step(self, batch, batch_idx): topk = self.cfg.get('inference_topk', 80) use_cfg = self.cfg.get('inference_use_cfg', False) cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( batch, max_decoder_steps=self.cfg.get('max_decoder_steps', 500), temperature=temperature, @@ -948,619 +2542,198 @@ def test_step(self, batch, batch_idx): use_cfg=use_cfg, cfg_scale=cfg_scale, ) - for idx in range(predicted_audio.size(0)): - predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() - predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] - item_idx = batch_idx * test_dl_batch_size + idx - self.tb_logger.add_audio( - 'predicted_audio', - predicted_audio_np, - global_step=item_idx, - sample_rate=self.cfg.sample_rate, - ) - # Save the predicted audio - log_dir = self.logger.log_dir - audio_dir = os.path.join(log_dir, 'audios') - if not os.path.exists(audio_dir): - os.makedirs(audio_dir) - audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') - sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) + + for logger in self.loggers: + is_wandb = isinstance(logger, WandbLogger) + is_tb = isinstance(logger, TensorBoardLogger) + if 
not is_wandb and not is_tb: + raise ValueError( + f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." + ) + + for idx in range(predicted_audio.size(0)): + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] + item_idx = batch_idx * test_dl_batch_size + idx + + if is_wandb: + log_dict = { + "test/predicted_audio": wandb.Audio( + predicted_audio_np, sample_rate=self.sample_rate, caption="Predicted Audio" + ), + } + logger.experiment.log(log_dict, step=item_idx) + + if is_tb: + logger.experiment.add_audio( + 'test/predicted_audio', + predicted_audio_np, + global_step=item_idx, + sample_rate=self.sample_rate, + ) + + # Save the predicted audio + log_dir = logger.log_dir + audio_dir = os.path.join(log_dir, 'audios') + if not os.path.exists(audio_dir): + os.makedirs(audio_dir) + audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') + sf.write(audio_path, predicted_audio_np, self.sample_rate) def on_validation_epoch_end(self): collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() val_loss = collect("val_loss") val_codebook_loss = collect("val_codebook_loss") val_alignment_loss = collect("val_alignment_loss") - self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) - self.log("val_codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) - self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) + val_aligner_encoder_loss = collect("val_aligner_encoder_loss") + # Log val_loss in the same group as the other val metrics. + self.log("val/loss", val_loss, prog_bar=True, sync_dist=True) + # Ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs. 
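+ # `logger=False` keeps this duplicate metric out of the experiment loggers (wandb/TensorBoard) while it
+ # is still written to `trainer.callback_metrics`, which is what checkpoint callbacks monitoring `val_loss` read.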
+ self.log( + "val_loss", + val_loss, + prog_bar=False, + sync_dist=True, + on_step=False, + on_epoch=True, + logger=False, + enable_graph=False, + ) + self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) + self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) + self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) + if self.local_transformer_type != LocalTransformerType.NO_LT: + val_local_transformer_loss = collect("val_local_transformer_loss") + self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() # free memory - def get_dataset(self, cfg, dataset_type): + def get_dataset(self, dataset_cfg, dataset_type): dataset = instantiate( - cfg.dataset, + dataset_cfg.dataset, + sample_rate=self.sample_rate, bos_id=self.bos_id, eos_id=self.eos_id, audio_bos_id=self.audio_bos_id, audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, context_audio_eos_id=self.context_audio_eos_id, - num_audio_codebooks=self.cfg.num_audio_codebooks, - codec_model_downsample_factor=self.cfg.codec_model_downsample_factor, + num_audio_codebooks=self.data_num_audio_codebooks, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=dataset_type, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, context_duration_min=self.cfg.context_duration_min, context_duration_max=self.cfg.context_duration_max, + text_context_remapping=self.text_context_remapping, + text_context_remapping_prob=self.text_context_remapping_prob, ) - dataset.load_16khz_audio = self.model_type == 'single_encoder_sv_tts' + dataset.load_16khz_audio = False dataset.tokenizer_config = ( self.cfg.text_tokenizers ) # This will be used in worker_init_fn for instantiating tokenizer return dataset - def _setup_train_dataloader(self, cfg): - dataset = self.get_dataset(cfg, dataset_type='train') - sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) - persistent_workers = True - if cfg.dataloader_params.num_workers == 0: - persistent_workers = False - # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) - dataset.text_tokenizer, dataset.text_conditioning_tokenizer = self._setup_tokenizers(self.cfg) - data_loader = torch.utils.data.DataLoader( - dataset, - collate_fn=dataset.collate_fn, - sampler=sampler, - **cfg.dataloader_params, - worker_init_fn=worker_init_fn, - persistent_workers=persistent_workers, + def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader: + # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also + # cfg is a classifier-free guidance. 
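+ # The Lhotse dataset below mirrors the settings of the map-style dataset built in `get_dataset`
+ # (special-token ids, codec frame geometry, context-duration limits, and text-conditioning options),
+ # keeping the two data paths consistent.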
+ dataset = MagpieTTSLhotseDataset( + sample_rate=self.sample_rate, + volume_norm=dataset_cfg.volume_norm, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, + audio_bos_id=self.audio_bos_id, + audio_eos_id=self.audio_eos_id, + context_audio_bos_id=self.context_audio_bos_id, + context_audio_eos_id=self.context_audio_eos_id, + num_audio_codebooks=self.data_num_audio_codebooks, + prior_scaling_factor=self.cfg.prior_scaling_factor, + load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, + dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) + load_16khz_audio=False, + pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, + context_duration_min=self.cfg.context_duration_min, + context_duration_max=self.cfg.context_duration_max, + use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, + tokenizer_config=self.cfg.text_tokenizers, + text_context_remapping=self.text_context_remapping, + text_context_remapping_prob=self.text_context_remapping_prob, ) - return data_loader - - def _setup_test_dataloader(self, cfg): - dataset = self.get_dataset(cfg, dataset_type='test') - persistent_workers = True - if cfg.dataloader_params.num_workers == 0: - persistent_workers = False - # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) - dataset.text_tokenizer, dataset.text_conditioning_tokenizer = self._setup_tokenizers(self.cfg, mode='test') - - data_loader = torch.utils.data.DataLoader( - dataset, - collate_fn=dataset.collate_fn, - **cfg.dataloader_params, - worker_init_fn=worker_init_fn, - persistent_workers=persistent_workers, + data_loader = get_lhotse_dataloader_from_config( + config=dataset_cfg.dataset, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=dataset, ) return data_loader - def setup_training_data(self, cfg): - self._train_dl = self._setup_train_dataloader(cfg) - - def setup_validation_data(self, cfg): - self._validation_dl = self._setup_test_dataloader(cfg) - - def setup_test_data(self, cfg): - self._test_dl = self._setup_test_dataloader(cfg) - - @classmethod - def list_available_models(cls) -> List[PretrainedModelInfo]: - return [] - - -class MagpieTTS_ModelInference(MagpieTTS_Model): - """Small override of MagpieTTS_Model for parallel multi-GPU inference and metrics calculation. - This class is used in 'test' mode and leverages trainer.test() for multi-GPU/multi-node inference. - Saves the predicted audio files and logs the CER/WER metrics as individual json files for each audio. 
- """ - - def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): - super().__init__(cfg, trainer) - if cfg.get('pref_set_language', "en") == "en": - self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained( - model_name="nvidia/parakeet-tdt-1.1b" - ) - self.eval_asr_model.freeze() - self.eval_asr_model.eval() - - self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - model_name='titanet_large' - ) - self.eval_speaker_verification_model.freeze() - self.eval_speaker_verification_model.eval() - - if cfg.get('load_whisper_model', False): - from transformers import WhisperForConditionalGeneration, WhisperProcessor - - self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") - self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") - self.whisper_model.eval() - - def transcribe_with_whisper(self, audio_filepath, language): - speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000) - forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language=language) if language else None - inputs = self.whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features - inputs = inputs.to(self.device) - with torch.no_grad(): - predicted_ids = self.whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) - transcription = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) - result = transcription[0] - return result - - def process_text(self, input_text): - """ - Normalizes text for CER/WER calculation. - Taken from hallucination_eval.py - """ - # Convert text to lowercase - lower_case_text = input_text.lower() - - # Remove commas from text - no_comma_text = lower_case_text.replace(",", "") - # Replace "-" with spaces - no_dash_text = no_comma_text.replace("-", " ") - no_dash_text = no_dash_text.replace("'", "") - no_dash_text = no_dash_text.replace(";", "") - no_dash_text = no_dash_text.replace(".", "") - - # Replace double spaces with single space - single_space_text = " ".join(no_dash_text.split()) - - single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) - - # @shehzeen: Added this to handle some common errors in ASR transcripts - single_space_text.replace("h t t p", "http") - single_space_text.replace("w w w", "www") - - return single_space_text - - def get_speaker_embeddings_from_filepaths(self, filepaths): - audio_batch = [] - audio_lengths = [] - for filepath in filepaths: - audio, sr = sf.read(filepath) - if sr != 16000: - audio = librosa.core.resample(audio, orig_sr=sr, target_sr=16000) - audio_tensor = torch.tensor(audio, dtype=torch.float32, device=self.device) - audio_batch.append(audio_tensor) - audio_lengths.append(audio_tensor.size(0)) - - batch_audio_lens = torch.tensor(audio_lengths, device=self.device).long() - max_audio_len = int(batch_audio_lens.max().item()) - audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len]) - - _, speaker_embeddings = self.eval_speaker_verification_model.forward( - input_signal=audio_batch, input_signal_length=batch_audio_lens - ) - - return speaker_embeddings - - def test_step(self, batch, batch_idx): - with torch.no_grad(): - test_dl_batch_size = self._test_dl.batch_size - temperature = self.cfg.get('inference_temperature', 0.7) - topk = self.cfg.get('inference_topk', 80) - use_cfg = self.cfg.get('inference_use_cfg', False) - cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) - 
predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens = self.infer_batch( - batch, - max_decoder_steps=self.cfg.get('max_decoder_steps', 500), - temperature=temperature, - topk=topk, - use_cfg=use_cfg, - cfg_scale=cfg_scale, - ) - predicted_audio_paths = [] - audio_durations = [] - batch_invalid = False - for idx in range(predicted_audio.size(0)): - predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() - predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] - item_idx = batch_idx * test_dl_batch_size + idx - # Save the predicted audio - log_dir = self.logger.log_dir - audio_dir = os.path.join(log_dir, 'audios') - if not os.path.exists(audio_dir): - os.makedirs(audio_dir) - audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') - audio_durations.append(len(predicted_audio_np) / self.cfg.sample_rate) - sf.write(audio_path, predicted_audio_np, self.cfg.sample_rate) - - predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) - predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]] - torch.save( - predicted_codes_torch, - os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'), - ) - predicted_audio_paths.append(audio_path) - - if not batch_invalid: - with torch.no_grad(): - try: - if self.cfg.get("pref_set_language", "en") == "en": - pred_transcripts = self.eval_asr_model.transcribe( - predicted_audio_paths, batch_size=len(predicted_audio_paths) - )[0] - pred_transcripts = [self.process_text(transcript) for transcript in pred_transcripts] - else: - pred_transcripts = [ - self.transcribe_with_whisper(audio_path, self.cfg.pref_set_language) - for audio_path in predicted_audio_paths - ] - pred_transcripts = [self.process_text(transcript) for transcript in pred_transcripts] - except Exception as e: - assert ( - predicted_audio_lens[idx] < 1000 - ).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}" - logging.warning(f"Exception during ASR transcription: {e}") - logging.warning( - "Skipping processing of the batch; generating metrics indicating a WER of 100% and " - "Speaker Similarity of 0.0" - ) - batch_invalid = True - continue # don't break since we want to continue building audio durations list - pred_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(predicted_audio_paths) - gt_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(batch['audio_filepaths']) - - for idx in range(predicted_audio.size(0)): - if not batch_invalid: - item_idx = batch_idx * test_dl_batch_size + idx - pred_transcript = pred_transcripts[idx] - gt_transcript = self.process_text(batch['raw_texts'][idx]) - - cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True) - wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False) - - spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy() - spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy() - - spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( - np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) - ) - else: - # Create an entry indicating invalid metrics - cer_gt = 1.0 - wer_gt = 1.0 - spk_similarity = 0.0 - pred_transcript = "" # do not change this string; subsequent processing relies on it - gt_transcript = self.process_text(batch['raw_texts'][idx]) - - item_metrics = { - 'cer_gt': float(cer_gt), - 'wer_gt': float(wer_gt), - 'duration': 
audio_durations[idx], - 'spk_similarity': float(spk_similarity), - 'pred_transcript': pred_transcript, - 'gt_transcript': gt_transcript, - } - - with open( - os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w' - ) as f: - json.dump(item_metrics, f) + def setup_training_data(self, dataset_cfg): + if dataset_cfg.get("use_lhotse", False): + # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also + # cfg is a classifier-free guidance. + # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. + if not isinstance(dataset_cfg, DictConfig): + dataset_cfg = OmegaConf.create(dataset_cfg) + OmegaConf.set_struct(dataset_cfg.dataset, False) + dataset_cfg.dataset.update({"sample_rate": self.sample_rate}) + OmegaConf.set_struct(dataset_cfg.dataset, True) -class MagpieTTS_ModelDPO(MagpieTTS_Model): - """Extends MagpieTTS_Model to support Direct Preference Optimization (DPO) training. - This class is used for training the model with preference-based losses, including DPO, RPO, and IPO losses. - It maintains a frozen reference model to compare log probabilities between policy and reference outputs. - - """ - - def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): - """Initialize the MagpieTTS_ModelDPO class. - - Args: - cfg (DictConfig): Configuration object containing model hyperparameters. - trainer (Trainer, optional): Trainer instance for model training. - """ - super().__init__(cfg, trainer) - # Create a copy of the configuration for the reference model - ref_model_cfg = copy.deepcopy(cfg) - with open_dict(ref_model_cfg): - ref_model_cfg.train_ds = None - ref_model_cfg.validation_ds = None - - # Initialize the frozen reference model - self._reference_model = MagpieTTS_Model(cfg=ref_model_cfg) - print("Loading reference model from checkpoint") - self._reference_model.load_state_dict( - torch.load(cfg.reference_model_ckpt_path, map_location="cpu")['state_dict'] - ) - self.freeze_model(self._reference_model) - self._reference_model.eval() - self._reference_model._no_state_dict = True - print("Reference model loaded and frozen") - - def state_dict(self, destination=None, prefix='', keep_vars=False): - """Return the state dictionary excluding non-trainable components. - - Excludes state keys related to `_speaker_verification_model`, `_codec_model`, and `_reference_model`. - - Args: - destination (dict, optional): The destination dictionary for the state_dict. - prefix (str, optional): Prefix to prepend to keys. - keep_vars (bool, optional): If True, tensors in the returned dictionary will not be detached. - - Returns: - dict: Filtered state dictionary. - """ - state_dict = super().state_dict(destination, prefix, keep_vars) - keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model'] - for key in list(state_dict.keys()): - if any([substring in key for substring in keys_substrings_to_exclude]): - del state_dict[key] - return state_dict - - def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. - Shape: (batch_size, sequence_length) - average_log_prob: If True, return the average log probability per (non-masked) token. 
Otherwise, return - the sum of the log probabilities of the (non-masked) tokens. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under - the given logits. - """ - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train') else: - return (per_token_logps * loss_mask).sum(-1) - - def preference_loss( - self, - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - chosen_gt_rewards=None, - rejected_gt_rewards=None, - beta=0.2, - gt_reward_scale=1.0, - label_smoothing=0, - loss_type="dpo", - reference_free=False, - ): - """Compute the DPO loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: Log probabilities of the policy model for the chosen responses. - Shape: (batch_size,) - policy_rejected_logps: Log probabilities of the policy model for the rejected responses. - Shape: (batch_size,) - reference_chosen_logps: Log probabilities of the reference model for the chosen responses. - Shape: (batch_size,) - reference_rejected_logps: Log probabilities of the reference model for the rejected responses. - Shape: (batch_size,) - beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore - the reference model as beta -> 0. - label_smoothing: conservativeness for DPO loss, which assumes that preferences are noisy (flipped with - probability label_smoothing) - ipo: If True, use the IPO loss instead of the DPO loss. - reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model - that assigns equal probability to all responses. - - Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). - The losses tensor contains the DPO loss for each example in the batch. - The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected - responses, respectively. - """ - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - if reference_free: - ref_logratios = 0 - - logits = pi_logratios - ref_logratios # also known as h_{\pi_\theta}^{y_w,y_l} - # logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps) - # logits = (policy_chosen_logps - reference_chosen_logps) - (policy_rejected_logps - reference_rejected_logps) - # logits is the same as rewards_delta in NeMo aligner - # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 - - if loss_type == "ipo": - losses = (logits - 1 / (2 * beta)) ** 2 # Eq. 
17 of https://arxiv.org/pdf/2310.12036v2.pdf - elif loss_type == "rpo": - # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 - logbeta_hat_chosen = torch.nn.functional.logsigmoid(beta * logits) - logbeta_hat_rejected = torch.nn.functional.logsigmoid(-beta * logits) - gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) - logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) - logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) - losses = torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) + torch.exp( - logalpha_hat_rejected - ) * (logalpha_hat_rejected - logbeta_hat_rejected) - elif loss_type == "rpo_sq": - gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) - losses = (beta * logits - gt_rewards_delta) ** 2 - elif loss_type == "dpo": - # Eq. 3 https://ericmitchell.ai/cdpo.pdf; - # label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) - F = torch.nn.functional - losses = ( - -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing + dataset = self.get_dataset(dataset_cfg, dataset_type='train') + sampler = dataset.get_sampler(dataset_cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) + persistent_workers = True + if dataset_cfg.dataloader_params.num_workers == 0: + persistent_workers = False + # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) + dataset.text_tokenizer = setup_tokenizers( + all_tokenizers_config=self.cfg.text_tokenizers, + mode='train', + ) + self._train_dl = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + sampler=sampler, + **dataset_cfg.dataloader_params, + worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers, ) - else: - raise NotImplementedError("loss type {} is not implemented".format(loss_type)) - - chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() - rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() - - return losses, chosen_rewards, rejected_rewards - - def process_batch_dpo(self, batch_chosen_rejected): - """Process a batch for Direct Preference Optimization (DPO) training. - - This method computes the preference loss by comparing the model's policy outputs with a frozen reference model. - It processes chosen and rejected samples, extracts log probabilities for each codebook, and calculates the - preference loss based on the difference in likelihoods between chosen and rejected responses. - - Args: - batch_chosen_rejected (dict): A dictionary containing two keys: - - 'chosen': The batch of chosen responses. - - 'rejected': The batch of rejected responses. - Returns: - dict: A dictionary containing: - - 'loss': The total computed loss. - - 'pref_loss': The preference loss. - - 'sft_loss': The supervised fine-tuning loss. - - 'alignment_loss': The alignment loss, if applicable. 
- """ - batch_chosen = batch_chosen_rejected['chosen'] - batch_rejected = batch_chosen_rejected['rejected'] - - model_output_chosen = self.process_batch(batch_chosen) - model_output_rejected = self.process_batch(batch_rejected) - with torch.no_grad(): - reference_model_output_chosen = self._reference_model.process_batch(batch_chosen) - reference_model_output_rejected = self._reference_model.process_batch(batch_rejected) - - chosen_policy_logprobs = None - rejected_policy_logprobs = None - chosen_ref_logprobs = None - rejected_ref_logprobs = None - for codebook_idx in range(self.cfg.num_audio_codebooks): - si = codebook_idx * self.cfg.num_audio_tokens_per_codebook - ei = si + self.cfg.num_audio_tokens_per_codebook - codebook_logits_chosen = model_output_chosen['logits'][:, :, si:ei] - codebook_logits_rejected = model_output_rejected['logits'][:, :, si:ei] - - ref_codebook_logits_chosen = reference_model_output_chosen['logits'][:, :, si:ei] - ref_codebook_logits_rejected = reference_model_output_rejected['logits'][:, :, si:ei] - - codebook_labels_chosen = model_output_chosen['audio_codes_target'][:, codebook_idx] - codebook_labels_rejected = model_output_rejected['audio_codes_target'][:, codebook_idx] - - codebook_log_probs_chosen = self._get_batch_logps( - codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask'] - ) - codebook_log_probs_rejected = self._get_batch_logps( - codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask'] + def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: + if dataset_cfg.get("use_lhotse", False): + # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. + if not isinstance(dataset_cfg, DictConfig): + dataset_cfg = OmegaConf.create(dataset_cfg) + OmegaConf.set_struct(dataset_cfg.dataset, False) + dataset_cfg.dataset.update({"sample_rate": self.sample_rate}) + OmegaConf.set_struct(dataset_cfg.dataset, True) + data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test') + else: + dataset = self.get_dataset(dataset_cfg, dataset_type='test') + persistent_workers = True + if dataset_cfg.dataloader_params.num_workers == 0: + persistent_workers = False + # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) + dataset.text_tokenizer = setup_tokenizers(all_tokenizers_config=self.cfg.text_tokenizers, mode='test') + + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + **dataset_cfg.dataloader_params, + worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers, ) - with torch.no_grad(): - ref_codebook_log_probs_chosen = self._get_batch_logps( - ref_codebook_logits_chosen, codebook_labels_chosen, reference_model_output_chosen['loss_mask'] - ) - ref_codebook_log_probs_rejected = self._get_batch_logps( - ref_codebook_logits_rejected, - codebook_labels_rejected, - reference_model_output_rejected['loss_mask'], - ) - - if chosen_policy_logprobs is None: - chosen_policy_logprobs = codebook_log_probs_chosen - rejected_policy_logprobs = codebook_log_probs_rejected - chosen_ref_logprobs = ref_codebook_log_probs_chosen - rejected_ref_logprobs = ref_codebook_log_probs_rejected - else: - chosen_policy_logprobs += codebook_log_probs_chosen - rejected_policy_logprobs += codebook_log_probs_rejected - chosen_ref_logprobs += ref_codebook_log_probs_chosen - rejected_ref_logprobs += ref_codebook_log_probs_rejected - - rewards_chosen = batch_chosen['rewards'] - rewards_rejected = 
batch_rejected['rewards'] - - assert torch.all(rewards_chosen == 1) - assert torch.all(rewards_rejected < 1) - - pref_loss, chosen_rewards, rejected_rewards = self.preference_loss( - chosen_policy_logprobs, - rejected_policy_logprobs, - chosen_ref_logprobs, - rejected_ref_logprobs, - chosen_gt_rewards=rewards_chosen, - rejected_gt_rewards=rewards_rejected, - beta=self.cfg.get('dpo_beta', 0.01), - loss_type=self.cfg.get('dpo_loss_type', 'dpo'), - ) - - pref_loss = pref_loss.mean() - sft_loss = -chosen_policy_logprobs.mean() - - pref_loss_weight = self.cfg.get('dpo_pref_loss_weight', 1.0) - sft_loss_weight = self.cfg.get('dpo_sft_loss_weight', 0.0) - loss = pref_loss_weight * pref_loss + sft_loss * sft_loss_weight - - alignment_loss = model_output_chosen['alignment_loss'] - if alignment_loss is not None: - loss += alignment_loss - - return { - 'loss': loss, - 'pref_loss': pref_loss, - 'sft_loss': sft_loss, - 'alignment_loss': alignment_loss, - } - - def training_step(self, batch, batch_idx): - """Perform a training step using DPO loss. - - Args: - batch (dict): Batch data containing chosen and rejected samples. - batch_idx (int): Index of the batch. - - Returns: - Tensor: Training loss. - """ - dpo_outputs = self.process_batch_dpo(batch) - self.log('train_loss', dpo_outputs['loss'], prog_bar=True, sync_dist=True) - self.log('train_pref_loss', dpo_outputs['pref_loss'], prog_bar=True, sync_dist=True) - self.log('train_sft_loss', dpo_outputs['sft_loss'], prog_bar=True, sync_dist=True) - return dpo_outputs['loss'] - - def validation_step(self, batch, batch_idx): - """Perform a validation step using DPO loss. - - Args: - batch (dict): Validation batch data. - batch_idx (int): Batch index. - """ - dpo_outputs = self.process_batch_dpo(batch) - - val_loss = dpo_outputs['loss'] - val_pref_loss = dpo_outputs['pref_loss'] - val_sft_loss = dpo_outputs['sft_loss'] - val_alignment_loss = dpo_outputs['alignment_loss'] - - self.validation_step_outputs.append( - { - 'val_loss': val_loss, - 'val_pref_loss': val_pref_loss, - 'val_sft_loss': val_sft_loss, - 'val_alignment_loss': val_alignment_loss, - } - ) + return data_loader - def on_validation_epoch_end(self): - """Aggregate validation losses at the end of the validation epoch.""" + def setup_validation_data(self, dataset_cfg): + self._validation_dl = self._setup_test_dataloader(dataset_cfg) - def collect(key): - values = [] - for x in self.validation_step_outputs: - if x[key] is not None: - values.append(x[key]) - else: - values.append(torch.tensor(0.0, device=self.device)) - stacked_values = torch.stack(values) - return stacked_values.mean() + def setup_test_data(self, dataset_cfg): + self._test_dl = self._setup_test_dataloader(dataset_cfg) - val_loss = collect("val_loss") - val_pref_loss = collect("val_pref_loss") - val_sft_loss = collect("val_sft_loss") - val_alignment_loss = collect("val_alignment_loss") - self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) - self.log("val_pref_loss", val_pref_loss, prog_bar=True, sync_dist=True) - self.log("val_sft_loss", val_sft_loss, prog_bar=True, sync_dist=True) - if val_alignment_loss is not None: - self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) - self.validation_step_outputs.clear() + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + return [] diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py new file mode 100644 index 000000000000..e2506d08e497 
--- /dev/null +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -0,0 +1,1106 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import json +import os +import random +import string +from typing import Optional + +import librosa +import numpy as np +import soundfile as sf +import torch +from lightning.pytorch import Trainer +from omegaconf import DictConfig, open_dict + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors +from nemo.utils import logging + +try: + import torchaudio + from torchaudio.pipelines import SQUIM_OBJECTIVE + + HAVE_TORCHAUDIO = True +except ImportError: + HAVE_TORCHAUDIO = False + +try: + from nemo_text_processing.text_normalization.normalize import Normalizer + + PYNINI_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + Normalizer = None + PYNINI_AVAILABLE = False + +from nemo.collections.tts.models import MagpieTTSModel + + +class MagpieTTSModelOfflinePODataGen(MagpieTTSModel): + """Small override of MagpieTTSModel for parallel multi-GPU inference and metrics calculation. + This class is used in 'test' mode and leverages trainer.test() for multi-GPU/multi-node inference. + Saves the predicted audio files and logs the CER/WER metrics as individual json files for each audio. 
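The docstring above describes the on-disk layout: each rank writes `predicted_audioRank{rank}_{idx}.wav`, the matching `_codes.pt`, and a `_metrics.json` holding CER/WER and speaker similarity for that item. As a rough sketch of how those per-item JSON files could later be gathered and ranked when building preference pairs, assuming only the file naming and metric keys used in this class (the aggregation helper itself is hypothetical and not part of this diff):

```python
import glob
import json
import os


def collect_generation_metrics(audio_dir: str):
    """Hypothetical helper: gather the *_metrics.json files written by
    MagpieTTSModelOfflinePODataGen.test_step across all ranks."""
    records = []
    for path in sorted(glob.glob(os.path.join(audio_dir, "predicted_audioRank*_metrics.json"))):
        with open(path) as f:
            m = json.load(f)
        m["metrics_path"] = path
        records.append(m)
    # Lower CER and higher speaker similarity are better; the combined score and
    # its weights are arbitrary choices for this sketch.
    for m in records:
        m["score"] = 0.5 * (1.0 - m["cer_gt"]) + 0.5 * m["spk_similarity"]
    return sorted(records, key=lambda m: m["score"], reverse=True)
```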
+ """ + + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + super().__init__(cfg, trainer) + if cfg.get('pref_set_language', "en") == "en": + self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained( + model_name="nvidia/parakeet-ctc-0.6b" + ) + self.eval_asr_model.freeze() + + self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_large' + ) + self.eval_speaker_verification_model.freeze() + + if cfg.get('load_whisper_model', False): + from transformers import WhisperForConditionalGeneration, WhisperProcessor + + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + self.whisper_model.eval() + self._normalize_whisper_transcript = cfg.get('normalize_whisper_transcript', True) + if self._normalize_whisper_transcript and PYNINI_AVAILABLE: + self._normalizer_cache = {} + # Pre-create normalizer for the configured language + lang = cfg.get('pref_set_language', 'en') + self._get_cached_normalizer(lang) + + def _get_cached_normalizer(self, lang_key): + """Get or create a cached normalizer for the given language.""" + if not PYNINI_AVAILABLE: + return None + lang_key = lang_key if lang_key else "en" + if lang_key not in self._normalizer_cache: + logging.info(f"Creating normalizer for language: {lang_key}") + self._normalizer_cache[lang_key] = Normalizer(input_case="cased", lang=lang_key) + return self._normalizer_cache[lang_key] + + def test_step(self, batch, batch_idx): + with torch.no_grad(): + test_dl_batch_size = self._test_dl.batch_size + temperature = self.cfg.get('inference_temperature', 0.7) + topk = self.cfg.get('inference_topk', 80) + use_cfg = self.cfg.get('inference_use_cfg', False) + cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( + batch, + max_decoder_steps=self.cfg.get('max_decoder_steps', 500), + temperature=temperature, + topk=topk, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + ) + predicted_audio_paths = [] + audio_durations = [] + batch_invalid = False + for idx in range(predicted_audio.size(0)): + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] + item_idx = batch_idx * test_dl_batch_size + idx + # Save the predicted audio + log_dir = self.logger.log_dir + audio_dir = os.path.join(log_dir, 'audios') + if not os.path.exists(audio_dir): + os.makedirs(audio_dir) + audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') + audio_durations.append(len(predicted_audio_np) / self.sample_rate) + sf.write(audio_path, predicted_audio_np, self.sample_rate) + + predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) + predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]] + torch.save( + predicted_codes_torch, + os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'), + ) + predicted_audio_paths.append(audio_path) + + if not batch_invalid: + with torch.no_grad(): + try: + if self.cfg.get("pref_set_language", "en") == "en": + pred_transcripts = self.eval_asr_model.transcribe( + predicted_audio_paths, batch_size=len(predicted_audio_paths) + ) + pred_transcripts = [ + process_text_for_cer(transcript.text) for transcript in pred_transcripts + ] + else: + pred_transcripts = 
[] + for audio_path in predicted_audio_paths: + normalizer = ( + self._get_cached_normalizer(self.cfg.pref_set_language) + if self._normalize_whisper_transcript + else None + ) + transcript = transcribe_with_whisper( + audio_path, + self.cfg.pref_set_language, + self.whisper_processor, + self.whisper_model, + self.device, + normalizer, + ) + pred_transcripts.append(transcript) + + pred_transcripts = [ + process_text_for_cer(transcript) for transcript in pred_transcripts + ] + except Exception as e: + assert ( + predicted_audio_lens[idx] < 1000 + ).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}" + logging.warning(f"Exception during ASR transcription: {e}") + logging.warning( + "Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0" + ) + batch_invalid = True + continue # don't break since we want to continue building audio durations list + pred_speaker_embeddings = get_speaker_embeddings_from_filepaths( + predicted_audio_paths, self.eval_speaker_verification_model, self.device + ) + gt_speaker_embeddings = get_speaker_embeddings_from_filepaths( + batch['audio_filepaths'], self.eval_speaker_verification_model, self.device + ) + + for idx in range(predicted_audio.size(0)): + if not batch_invalid: + item_idx = batch_idx * test_dl_batch_size + idx + pred_transcript = pred_transcripts[idx] + gt_transcript = process_text_for_cer(batch['raw_texts'][idx]) + + cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True) + wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False) + + spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy() + spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy() + + spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) + ) + else: + # Create an entry indicating invalid metrics + cer_gt = 1.0 + wer_gt = 1.0 + spk_similarity = 0.0 + pred_transcript = "" # do not change this string; subsequent processing relies on it + gt_transcript = process_text_for_cer(batch['raw_texts'][idx]) + + item_metrics = { + 'cer_gt': float(cer_gt), + 'wer_gt': float(wer_gt), + 'duration': audio_durations[idx], + 'spk_similarity': float(spk_similarity), + 'pred_transcript': pred_transcript, + 'gt_transcript': gt_transcript, + } + + with open( + os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w' + ) as f: + json.dump(item_metrics, f) + + +class MagpieTTSModelOfflinePO(MagpieTTSModel): + """ + MagpieTTS_Model_OfflinePO is a class that extends MagpieTTS_Model to support + offline preference optimization (DPO, IPO, RPO). + Set cfg.model.dpo_loss_type to 'dpo', 'ipo', or 'rpo' to use the corresponding loss. 
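For the `dpo` branch of `preference_loss` with `label_smoothing=0`, the objective reduces to standard DPO: increase the margin by which the policy prefers the chosen response over the rejected one, relative to the frozen reference model. A minimal standalone restatement with toy tensors (not wired into the model, and omitting the label-smoothing and IPO/RPO branches):

```python
import torch
import torch.nn.functional as F


def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    reference_chosen_logps, reference_rejected_logps, beta=0.2):
    """Toy restatement of the loss_type='dpo', label_smoothing=0 case."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios      # margin relative to the reference model
    return -F.logsigmoid(beta * logits)        # per-example loss, shape (batch,)


# Toy example: the policy prefers the chosen response more strongly than the reference does,
# so the margin is positive and the loss is small.
losses = dpo_loss_sketch(
    policy_chosen_logps=torch.tensor([-5.0]),
    policy_rejected_logps=torch.tensor([-9.0]),
    reference_chosen_logps=torch.tensor([-6.0]),
    reference_rejected_logps=torch.tensor([-7.0]),
)
print(losses)  # approx 0.44 for beta=0.2
```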
+ """ + + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + super().__init__(cfg, trainer) + ref_model_cfg = copy.deepcopy(cfg) + with open_dict(ref_model_cfg): + ref_model_cfg.train_ds = None + ref_model_cfg.validation_ds = None + self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) + print("Loading reference model from checkpoint") + self._reference_model.load_state_dict( + torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict'] + ) + self._reference_model.freeze() + self._reference_model._no_state_dict = True + print("Reference model loaded and frozen") + + def state_dict(self, destination=None, prefix='', keep_vars=False): + state_dict = super().state_dict(destination, prefix, keep_vars) + keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model'] + for key in list(state_dict.keys()): + if any([substring in key for substring in keys_substrings_to_exclude]): + del state_dict[key] + return state_dict + + def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False): + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def preference_loss( + self, + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + chosen_gt_rewards=None, + rejected_gt_rewards=None, + beta=0.2, + gt_reward_scale=1.0, + label_smoothing=0, + loss_type="dpo", + reference_free=False, + ): + """Compute the DPO loss for a batch of policy and reference model log probabilities. + Referenced From: https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. + label_smoothing: conservativeness for DPO loss, which assumes that preferences are noisy (flipped with probability label_smoothing) + ipo: If True, use the IPO loss instead of the DPO loss. + reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 
+ The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + if reference_free: + ref_logratios = 0 + + logits = pi_logratios - ref_logratios # also known as h_{\pi_\theta}^{y_w,y_l} + # logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps) + # logits = (policy_chosen_logps - reference_chosen_logps) - (policy_rejected_logps - reference_rejected_logps) + # logits is the same as rewards_delta in NeMo aligner: https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 + + if loss_type == "ipo": + losses = (logits - 1 / (2 * beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf + elif loss_type == "rpo": + # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241 + logbeta_hat_chosen = torch.nn.functional.logsigmoid(beta * logits) + logbeta_hat_rejected = torch.nn.functional.logsigmoid(-beta * logits) + gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) + logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) + logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) + losses = torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) + torch.exp( + logalpha_hat_rejected + ) * (logalpha_hat_rejected - logbeta_hat_rejected) + elif loss_type == "rpo_sq": + gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards) + losses = (beta * logits - gt_rewards_delta) ** 2 + elif loss_type == "dpo": + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 
7 of https://arxiv.org/pdf/2305.18290.pdf) + F = torch.nn.functional + losses = ( + -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing + ) + else: + raise NotImplementedError("loss type {} is not implemented".format(loss_type)) + + chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + + def process_batch_dpo(self, batch_chosen_rejected): + batch_chosen = batch_chosen_rejected['chosen'] + batch_rejected = batch_chosen_rejected['rejected'] + + model_output_chosen = self.process_batch(batch_chosen) + model_output_rejected = self.process_batch(batch_rejected) + with torch.no_grad(): + reference_model_output_chosen = self._reference_model.process_batch(batch_chosen) + reference_model_output_rejected = self._reference_model.process_batch(batch_rejected) + + chosen_policy_logprobs = None + rejected_policy_logprobs = None + chosen_ref_logprobs = None + rejected_ref_logprobs = None + for codebook_idx in range(self.num_audio_codebooks): + si = codebook_idx * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + codebook_logits_chosen = model_output_chosen['logits'][:, :, si:ei] + codebook_logits_rejected = model_output_rejected['logits'][:, :, si:ei] + + ref_codebook_logits_chosen = reference_model_output_chosen['logits'][:, :, si:ei] + ref_codebook_logits_rejected = reference_model_output_rejected['logits'][:, :, si:ei] + + codebook_labels_chosen = model_output_chosen['audio_codes_target'][:, codebook_idx] + codebook_labels_rejected = model_output_rejected['audio_codes_target'][:, codebook_idx] + + codebook_log_probs_chosen = self._get_batch_logps( + codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask'][:, codebook_idx] + ) + codebook_log_probs_rejected = self._get_batch_logps( + codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask'][:, codebook_idx] + ) + with torch.no_grad(): + ref_codebook_log_probs_chosen = self._get_batch_logps( + ref_codebook_logits_chosen, + codebook_labels_chosen, + reference_model_output_chosen['loss_mask'][:, codebook_idx], + ) + ref_codebook_log_probs_rejected = self._get_batch_logps( + ref_codebook_logits_rejected, + codebook_labels_rejected, + reference_model_output_rejected['loss_mask'][:, codebook_idx], + ) + + if chosen_policy_logprobs is None: + chosen_policy_logprobs = codebook_log_probs_chosen + rejected_policy_logprobs = codebook_log_probs_rejected + chosen_ref_logprobs = ref_codebook_log_probs_chosen + rejected_ref_logprobs = ref_codebook_log_probs_rejected + else: + chosen_policy_logprobs += codebook_log_probs_chosen + rejected_policy_logprobs += codebook_log_probs_rejected + chosen_ref_logprobs += ref_codebook_log_probs_chosen + rejected_ref_logprobs += ref_codebook_log_probs_rejected + + rewards_chosen = batch_chosen['rewards'] + rewards_rejected = batch_rejected['rewards'] + + assert torch.all(rewards_chosen == 1) + assert torch.all(rewards_rejected < 1) + + pref_loss, chosen_rewards, rejected_rewards = self.preference_loss( + chosen_policy_logprobs, + rejected_policy_logprobs, + chosen_ref_logprobs, + rejected_ref_logprobs, + chosen_gt_rewards=rewards_chosen, + rejected_gt_rewards=rewards_rejected, + beta=self.cfg.get('dpo_beta', 0.01), + loss_type=self.cfg.get('dpo_loss_type', 'dpo'), + ) + + pref_loss = pref_loss.mean() + sft_loss = 
-chosen_policy_logprobs.mean() + + pref_loss_weight = self.cfg.get('dpo_pref_loss_weight', 1.0) + sft_loss_weight = self.cfg.get('dpo_sft_loss_weight', 0.0) + loss = pref_loss_weight * pref_loss + sft_loss * sft_loss_weight + + alignment_loss = model_output_chosen['alignment_loss'] + if alignment_loss is not None: + loss += alignment_loss + + return { + 'loss': loss, + 'pref_loss': pref_loss, + 'sft_loss': sft_loss, + 'alignment_loss': alignment_loss, + } + + def training_step(self, batch, batch_idx): + dpo_outputs = self.process_batch_dpo(batch) + self.log('train_loss', dpo_outputs['loss'], prog_bar=True, sync_dist=True) + self.log('train_pref_loss', dpo_outputs['pref_loss'], prog_bar=True, sync_dist=True) + self.log('train_sft_loss', dpo_outputs['sft_loss'], prog_bar=True, sync_dist=True) + return dpo_outputs['loss'] + + def validation_step(self, batch, batch_idx): + dpo_outputs = self.process_batch_dpo(batch) + + val_loss = dpo_outputs['loss'] + val_pref_loss = dpo_outputs['pref_loss'] + val_sft_loss = dpo_outputs['sft_loss'] + val_alignment_loss = dpo_outputs['alignment_loss'] + + self.validation_step_outputs.append( + { + 'val_loss': val_loss, + 'val_pref_loss': val_pref_loss, + 'val_sft_loss': val_sft_loss, + 'val_alignment_loss': val_alignment_loss, + } + ) + + def on_validation_epoch_end(self): + def collect(key): + values = [] + for x in self.validation_step_outputs: + if x[key] is not None: + values.append(x[key]) + else: + values.append(torch.tensor(0.0, device=self.device)) + stacked_values = torch.stack(values) + return stacked_values.mean() + + val_loss = collect("val_loss") + val_pref_loss = collect("val_pref_loss") + val_sft_loss = collect("val_sft_loss") + val_alignment_loss = collect("val_alignment_loss") + self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) + self.log("val_pref_loss", val_pref_loss, prog_bar=True, sync_dist=True) + self.log("val_sft_loss", val_sft_loss, prog_bar=True, sync_dist=True) + if val_alignment_loss is not None: + self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) + self.validation_step_outputs.clear() + + +class MagpieTTSModelOnlinePO(MagpieTTSModel): + """ + MagpieTTS_Model_OnlinePO is a class that extends MagpieTTS_Model to support + online preference optimization (GRPO). 
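Further down in this class, `generate_and_reward` samples several generations per prompt, maps each one's CER and speaker similarity onto piecewise-linear rewards centred at 0.5 (anchored by `mean_cer_dataset`, `mean_ssim_dataset`, and `best_ssim_achievable`), mixes them with the configured weights, and converts rewards into group-relative advantages: reward minus the group mean, divided by the group standard deviation when `scale_rewards` is set. A compact sketch of that reward-to-advantage path with toy numbers; the `shaped_reward` helper is illustrative only and uses the config defaults from this file:

```python
import numpy as np


def shaped_reward(cer, ssim, mean_cer=0.1, mean_ssim=0.6, best_ssim=0.9,
                  cer_w=0.5, ssim_w=0.5):
    """Illustrative restatement of the CER/SSIM reward shaping used for GRPO."""
    cer = min(max(cer, 0.0), 1.0)
    ssim = max(min(ssim, best_ssim), 0.0)
    if cer <= mean_cer:
        cer_r = 0.5 + 0.5 * (mean_cer - cer) / mean_cer            # 0.5 .. 1.0
    else:
        cer_r = 0.5 - 0.5 * (cer - mean_cer) / (1.0 - mean_cer)    # 0.0 .. 0.5
    if ssim >= mean_ssim:
        ssim_r = 0.5 + 0.5 * (ssim - mean_ssim) / (best_ssim - mean_ssim)
    else:
        ssim_r = 0.5 - 0.5 * (mean_ssim - ssim) / mean_ssim
    return cer_w * cer_r + ssim_w * ssim_r


# One group = all generations for the same prompt.
group_rewards = np.array([shaped_reward(0.02, 0.85), shaped_reward(0.30, 0.55)])
advantages = group_rewards - group_rewards.mean()
advantages = advantages / (group_rewards.std() + 1e-4)   # applied when scale_rewards=True
print(group_rewards, advantages)                          # roughly [0.91, 0.42] -> [+1, -1]
```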
+ """ + + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + super().__init__(cfg, trainer) + # Copy cfg + ref_model_cfg = copy.deepcopy(cfg) + with open_dict(ref_model_cfg): + ref_model_cfg.train_ds = None + ref_model_cfg.validation_ds = None + + self.reference_free = self.cfg.get('reference_free', False) # True means we dont use the reference model + if not self.reference_free: + self._reference_model = MagpieTTSModel(cfg=ref_model_cfg) + print("Loading reference model from checkpoint") + self._reference_model.load_state_dict( + torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict'] + ) + self._reference_model.freeze() + self._reference_model._no_state_dict = True + print("Reference model loaded and frozen") + + if cfg.get('reward_asr_model', "nemo") == "nemo": + self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained( + model_name="nvidia/parakeet-ctc-0.6b" + ) + self.eval_asr_model.freeze() + elif cfg.get('reward_asr_model', "nemo") == "whisper": + from transformers import WhisperForConditionalGeneration, WhisperProcessor + + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + self.whisper_model.eval() + else: + raise ValueError(f"Unknown reward_asr_model: {cfg.reward_asr_model}") + + self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_large' + ) + self.eval_speaker_verification_model.freeze() + + if cfg.get('load_whisper_model', False): + from transformers import WhisperForConditionalGeneration, WhisperProcessor + + self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + self.whisper_model.eval() + + use_pesq = self.cfg.get('use_pesq', False) + if use_pesq: + assert HAVE_TORCHAUDIO, "torchaudio is required for PESQ reward" + self.squim_objective_model = SQUIM_OBJECTIVE.get_model() + + self.loss_type = self.cfg.get('loss_type', 'grpo') + if self.loss_type not in ['grpo', 'dr_grpo']: + raise ValueError( + f"Received loss_type of {self.loss_type}, but the model only accepts one of ['grpo', 'dr_grpo']" + ) + self.scale_rewards = self.cfg.get('scale_rewards', True) + self.max_decoder_steps = self.cfg.get('max_decoder_steps', 430) + + self._normalize_whisper_transcript = self.cfg.get('normalize_whisper_transcript', True) + if cfg.get('reward_asr_model', "nemo") == "whisper" and self._normalize_whisper_transcript: + self._normalizer_cache = {} + + # If the best record in the group is above this threshold, we will not use that group for training + # Setting this to 1.0, because we clamp the ASR rewards to be in [0, 1] for OnlinePO + self.best_cer_threshold = self.cfg.get('best_cer_threshold', 1.0) + # If the worst record in the group exceeds this threshold, we will not use that group for training + # Setting this to 1.0, because we clamp the ASR rewards to be in [0, 1] for OnlinePO + self.worst_cer_threshold = self.cfg.get('worst_cer_threshold', 1.0) + + def _get_cached_normalizer(self, lang_key): + """Get or create a cached normalizer for the given language.""" + if not PYNINI_AVAILABLE: + return None + lang_key = lang_key if lang_key else "en" + if lang_key not in self._normalizer_cache: + logging.info(f"Creating normalizer for language: {lang_key}") + self._normalizer_cache[lang_key] = 
Normalizer(input_case="cased", lang=lang_key) + return self._normalizer_cache[lang_key] + + def state_dict(self, destination=None, prefix='', keep_vars=False): + state_dict = super().state_dict(destination, prefix, keep_vars) + keys_substrings_to_exclude = [ + '_speaker_verification_model', + '_codec_model', + '_reference_model', + 'eval_asr_model', + 'eval_speaker_verification_model', + 'whisper_model', + ] + for key in list(state_dict.keys()): + if any([substring in key for substring in keys_substrings_to_exclude]): + del state_dict[key] + return state_dict + + def _get_per_token_logps(self, logits, labels, loss_mask): + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. + """ + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = per_token_logps * loss_mask + return per_token_logps + + def repeat_items_in_batch(self, batch, num_repeats): + repeated_batch = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + repeated_value = value.repeat_interleave(num_repeats, dim=0) + elif isinstance(value, list): + repeated_value = [] + for item in value: + repeated_value.extend([item] * num_repeats) + else: + repeated_value = value + repeated_batch[key] = repeated_value + return repeated_batch + + def generate_and_reward( + self, batch, num_generations_per_item, mode='train', use_local_transformer_for_inference=False + ): + batch_repeated = self.repeat_items_in_batch(batch, num_generations_per_item) + temperature = self.cfg.get('inference_temperature', 0.7) + topk = self.cfg.get('inference_topk', 80) + use_cfg = False + cfg_scale = 1.0 + use_pesq = self.cfg.get('use_pesq', False) + inference_cfg_prob = self.cfg.get('inference_cfg_prob', 0.0) + if (inference_cfg_prob == 1.0) or (inference_cfg_prob > 0.0 and mode == 'train'): + # Randomly set use_cfg based on the given probability + use_cfg = random.random() < self.cfg.inference_cfg_prob + cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) + + predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch( + batch_repeated, + max_decoder_steps=self.max_decoder_steps, + temperature=temperature, + topk=topk, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + use_local_transformer_for_inference=use_local_transformer_for_inference, + use_LT_kv_cache=False, # We don't use KV caching for local transformer in GRPO due to issues. 
+ ) + predicted_audio_paths = [] + audio_durations = [] + for idx in range(predicted_audio.size(0)): + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] + if predicted_audio_np.shape[0] < 1000: + # Corner case to handle short audio files + predicted_audio_np = np.pad(predicted_audio_np, (0, 1000 - predicted_audio_np.shape[0])) + item_idx = idx + # Save the predicted audio + log_dir = self.logger.log_dir + audio_dir = os.path.join(log_dir, 'audios') + os.makedirs(audio_dir, exist_ok=True) + audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') + audio_durations.append(len(predicted_audio_np) / self.sample_rate) + sf.write(audio_path, predicted_audio_np, self.sample_rate) + + predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16) + predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]] # C, T + torch.save( + predicted_codes_torch, + os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'), + ) + predicted_audio_paths.append(audio_path) + + with torch.no_grad(): + if self.cfg.get("reward_asr_model", "nemo") == "nemo": + pred_transcripts = self.eval_asr_model.transcribe( + predicted_audio_paths, batch_size=len(predicted_audio_paths) + ) + pred_transcripts = [process_text_for_cer(transcript.text) for transcript in pred_transcripts] + elif self.cfg.get("reward_asr_model", "nemo") == "whisper": + pred_transcripts = [] + for item_idx, audio_path in enumerate(predicted_audio_paths): + language = batch_repeated['languages'][item_idx] + normalizer = self._get_cached_normalizer(language) if self._normalize_whisper_transcript else None + transcript = transcribe_with_whisper( + audio_path, language, self.whisper_processor, self.whisper_model, self.device, normalizer + ) + pred_transcripts.append(transcript) + pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts] + else: + # Address CodeQL issue where pred_transcripts might be undefined for future code + raise ValueError( + f"{self} received a value of {self.cfg.get('reward_asr_model', 'nemo')} in cfg.reward_asr_model " + "but this class only supports 'nemo' or 'whisper'." 
+ ) + + pred_speaker_embeddings = get_speaker_embeddings_from_filepaths( + predicted_audio_paths, self.eval_speaker_verification_model, self.device + ) + gt_speaker_embeddings = get_speaker_embeddings_from_filepaths( + batch_repeated['audio_filepaths'], self.eval_speaker_verification_model, self.device + ) + + batch_metrics = [] + cer_reward_weight = self.cfg.get('cer_reward_weight', 0.5) + ssim_reward_weight = self.cfg.get('ssim_reward_weight', 0.5) + pesq_reward_weight = self.cfg.get('pesq_reward_weight', 0.0) + for idx in range(predicted_audio.size(0)): + audio_path = predicted_audio_paths[idx] + item_idx = idx + pred_transcript = pred_transcripts[idx] + gt_transcript = process_text_for_cer(batch_repeated['raw_texts'][idx]) + cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True) + wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False) + cer_gt = min(max(cer_gt, 0.0), 1.0) # Ensure CER is in [0, 1] + wer_gt = min(max(wer_gt, 0.0), 1.0) # Ensure WER is in [0, 1] + spk_embedding_pred = pred_speaker_embeddings[idx].cpu().float().numpy() + spk_embedding_gt = gt_speaker_embeddings[idx].cpu().float().numpy() + spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) + ) + if use_pesq: + sample_audio, sr = torchaudio.load(audio_path) + sample_audio = sample_audio.to(self.device) + if sr != 16000: + sample_audio = torchaudio.functional.resample(sample_audio, sr, 16000) + _, pesq_hyp, _ = self.squim_objective_model(sample_audio) + pesq_hyp = pesq_hyp.item() + + item_metrics = { + 'cer_gt': float(cer_gt), + 'wer_gt': float(wer_gt), + 'duration': audio_durations[idx], + 'spk_similarity': float(spk_similarity), + 'pred_transcript': pred_transcript, + 'gt_transcript': gt_transcript, + 'codes_len': predicted_codes_lens[idx].item(), + 'pesq': pesq_hyp if use_pesq else 0.0, + } + with open( + os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w' + ) as f: + json.dump(item_metrics, f) + + batch_metrics.append(item_metrics) + + num_groups = len(batch['audio_filepaths']) + + best_ssim_achievable = self.cfg.get( + "best_ssim_achievable", 0.9 + ) # Examples with this speaker similarity or higher will have SSIM reward of 1 + mean_cer_dataset = self.cfg.get("mean_cer_dataset", 0.1) # CER equal to this value will have reward of 0.5 + mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6) # SSIM equal to this value will have reward of 0.5 + all_groups_mean_reward = 0.0 + all_groups_std_reward = 0.0 + group_validities = [] + for group_idx in range(num_groups): + group_start_idx = group_idx * num_generations_per_item + group_end_idx = group_start_idx + num_generations_per_item + group_rewards = [] + mean_reward = 0 + is_group_valid = True + group_best_cer = 1.0 + group_worst_cer = 0.0 + for idx in range(group_start_idx, group_end_idx): + # Lower CER and higher speaker similarity is better, means high reward + # Higher pesq is better, means high reward + # Reward for best CER and best speaker similarity should be 1 + item_cer = batch_metrics[idx]['cer_gt'] + item_ssim = batch_metrics[idx]['spk_similarity'] + item_cer = min(max(item_cer, 0.0), 1.0) + item_ssim = max(min(item_ssim, best_ssim_achievable), 0.0) + item_pesq = batch_metrics[idx]['pesq'] + group_best_cer = min(group_best_cer, item_cer) + group_worst_cer = max(group_worst_cer, item_cer) + + if item_cer <= mean_cer_dataset: + cer_reward = 0.5 + 0.5 * (mean_cer_dataset - item_cer) / 
mean_cer_dataset # 0.5 to 1 + else: + cer_reward = 0.5 - 0.5 * (item_cer - mean_cer_dataset) / (1 - mean_cer_dataset) # 0 to 0.5 + if item_ssim >= mean_ssim_dataset: + spk_similarity_reward = 0.5 + 0.5 * (item_ssim - mean_ssim_dataset) / ( + best_ssim_achievable - mean_ssim_dataset + ) + else: + spk_similarity_reward = 0.5 - 0.5 * (mean_ssim_dataset - item_ssim) / (mean_ssim_dataset) + if use_pesq: + pesq_reward = item_pesq / 4.5 + else: + pesq_reward = 0.0 + + batch_metrics[idx]['reward'] = ( + cer_reward * cer_reward_weight + + spk_similarity_reward * ssim_reward_weight + + pesq_reward * pesq_reward_weight + ) + + if (batch_metrics[idx]['codes_len'] >= 425) or ( + batch_metrics[idx]['codes_len'] <= 3 + ): # TODO: Remove hardcoded lengths + # This means it did not complete the sentence or generated an extremely short sentence + batch_metrics[idx]['reward'] = 0.0 + print( + "Item idx: ", + idx, + " CER: ", + item_cer, + " SSIM: ", + item_ssim, + " Reward: ", + batch_metrics[idx]['reward'], + " Codes len: ", + batch_metrics[idx]['codes_len'], + ) + batch_metrics[idx]['cer_reward'] = cer_reward + batch_metrics[idx]['spk_similarity_reward'] = spk_similarity_reward + batch_metrics[idx]['pesq_reward'] = pesq_reward + mean_reward += batch_metrics[idx]['reward'] + group_rewards.append(batch_metrics[idx]['reward']) + + if group_best_cer > self.best_cer_threshold: + is_group_valid = False + print( + f"Group {group_idx} has best CER {group_best_cer} which is above the threshold {self.best_cer_threshold}. Group is invalid." + ) + + if group_worst_cer > self.worst_cer_threshold: + is_group_valid = False + print( + f"Group {group_idx} has worst CER {group_worst_cer} which is above the threshold {self.worst_cer_threshold}. Group is invalid." + ) + + for _ in range(num_generations_per_item): + group_validities.append(is_group_valid) + + mean_reward /= num_generations_per_item + std_reward = np.std(group_rewards) + all_groups_mean_reward += mean_reward + all_groups_std_reward += std_reward + for idx in range(group_start_idx, group_end_idx): + batch_metrics[idx]['advantage'] = batch_metrics[idx]['reward'] - mean_reward + if self.scale_rewards: + batch_metrics[idx]['advantage'] = batch_metrics[idx]['advantage'] / (std_reward + 1e-4) + + all_groups_mean_reward = all_groups_mean_reward / num_groups + all_groups_std_reward = all_groups_std_reward / num_groups + advantages = [x['advantage'] for x in batch_metrics] + advantages = torch.tensor(advantages, device=self.device) + print("Mean reward: ", all_groups_mean_reward) + + group_validities = torch.tensor(group_validities, device=self.device) + return { + 'mean_reward': torch.tensor(all_groups_mean_reward, device=self.device), + 'std_reward': torch.tensor(all_groups_std_reward, device=self.device), + 'batch_repeated': batch_repeated, + 'metrics': batch_metrics, + 'predicted_codes': predicted_codes, + 'predicted_codes_lens': predicted_codes_lens, + 'advantages': advantages, + 'group_validities': group_validities, + } + + def process_batch_online_po(self, batch, n_generations_per_item, mode='train'): + use_kv_cache_during_online_po = self.cfg.get("use_kv_cache_during_online_po", False) + if use_kv_cache_during_online_po: + self.use_kv_cache_for_inference = True + self.decoder.reset_cache(use_cache=True) + + use_local_transformer_for_inference = False + logits_key = 'logits' + use_local_transformer_prob = self.cfg.get('use_local_transformer_prob', 0.0) + if use_local_transformer_prob > 0.0 and mode == 'train': + use_local_transformer_for_inference = 
random.random() < use_local_transformer_prob + logits_key = 'local_transformer_logits' + + with torch.no_grad(): + self.eval() + generated_codes_and_metrics = self.generate_and_reward( + batch, + n_generations_per_item, + mode, + use_local_transformer_for_inference=use_local_transformer_for_inference, + ) + self.train() + + if use_kv_cache_during_online_po: + self.use_kv_cache_for_inference = False + self.decoder.reset_cache(use_cache=False) + + batch_repeated = generated_codes_and_metrics['batch_repeated'] + predicted_codes = generated_codes_and_metrics['predicted_codes'] # B, 8, T + predicted_codes_lens = generated_codes_and_metrics['predicted_codes_lens'] # B + predicted_codes = predicted_codes[:, :, : predicted_codes_lens.max()] + + advantages = generated_codes_and_metrics['advantages'] # B + # Add extra tokens for BOS and EOS + bos_tensor = torch.full( + (predicted_codes.size(0), predicted_codes.size(1), 1), + self.audio_bos_id, + dtype=predicted_codes.dtype, + device=predicted_codes.device, + ) + padding_tensor = torch.full( + (predicted_codes.size(0), predicted_codes.size(1), 1), + 0, + dtype=predicted_codes.dtype, + device=predicted_codes.device, + ) + predicted_codes = torch.cat([bos_tensor, predicted_codes, padding_tensor], dim=2) + for idx in range(predicted_codes.size(0)): + predicted_codes[idx, :, predicted_codes_lens[idx] + 1] = self.audio_eos_id # Accounts for BOS + batch_repeated['audio_codes'] = predicted_codes + batch_repeated['audio_codes_lens'] = predicted_codes_lens + 2 # Accounts for BOS and EOS + if 'audio' in batch_repeated: + del batch_repeated['audio'] + if 'audio_lens' in batch_repeated: + del batch_repeated['audio_lens'] + + policy_model_outputs = self.process_batch(batch_repeated) + + reference_model_output = ( + None # Address CodeQL issue even though this varibable is only used not self.reference_free + ) + if not self.reference_free: + with torch.no_grad(): + reference_model_output = self._reference_model.process_batch(batch_repeated) + + total_loss = None + total_kl = None + for codebook_idx in range(self.num_audio_codebooks): + policy_codebook_loss_mask = policy_model_outputs['loss_mask'][:, codebook_idx, :] + reference_codebook_loss_mask = ( + reference_model_output['loss_mask'][:, codebook_idx, :] if not self.reference_free else None + ) + si = codebook_idx * self.num_all_tokens_per_codebook + ei = si + self.num_all_tokens_per_codebook + + codebook_logits = policy_model_outputs[logits_key][:, :, si:ei] # B, T, C + codebook_labels = batch_repeated['audio_codes'][:, codebook_idx, 1:] + + per_token_codebook_log_probs = self._get_per_token_logps( + codebook_logits, codebook_labels, policy_codebook_loss_mask + ) + per_token_loss = -( + torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach()) + * advantages.unsqueeze(1) + ) + group_validities = generated_codes_and_metrics['group_validities'] # B * n_generations_per_item + per_token_loss = per_token_loss * group_validities.unsqueeze(1) # B, T + + if not self.reference_free: + with torch.no_grad(): + ref_codebook_logits = reference_model_output[logits_key][:, :, si:ei] + per_token_ref_codebook_log_probs = self._get_per_token_logps( + ref_codebook_logits, codebook_labels, reference_codebook_loss_mask + ) + # https://github.com/huggingface/trl/blob/ffcb9f4aee725a2bd072d0387afe68a4b1c7967c/trl/trainer/grpo_trainer.py#L703 + per_token_codebook_kl = ( + torch.exp(per_token_ref_codebook_log_probs - per_token_codebook_log_probs) + - (per_token_ref_codebook_log_probs - 
per_token_codebook_log_probs) + - 1 + ) + per_token_loss = per_token_loss + self.cfg.grpo_beta * per_token_codebook_kl + codebook_kl_loss_mean = ( + (per_token_codebook_kl * policy_codebook_loss_mask).sum(dim=1) + / policy_codebook_loss_mask.sum(dim=1) + ).mean() + else: + codebook_kl_loss_mean = torch.tensor(0.0, device=self.device) + + if self.loss_type == "grpo": + codebook_loss = ( + (per_token_loss * policy_codebook_loss_mask).sum(dim=1) / policy_codebook_loss_mask.sum(dim=1) + ).mean() + elif self.loss_type == "dr_grpo": + # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py + total_tokens = per_token_loss.shape[0] * self.max_decoder_steps + codebook_loss = (per_token_loss * policy_codebook_loss_mask).sum() / total_tokens + else: + raise ValueError(f"Unknown loss function: {self.loss_type}") + + if total_loss is None: + total_loss = codebook_loss + total_kl = codebook_kl_loss_mean + else: + total_loss += codebook_loss + total_kl += codebook_kl_loss_mean + + total_loss /= self.num_audio_codebooks + + return { + 'mean_reward': generated_codes_and_metrics['mean_reward'], + 'std_reward': generated_codes_and_metrics['std_reward'], + 'loss': total_loss, + 'kl_loss': total_kl, + 'batch_metrics': generated_codes_and_metrics['metrics'], + } + + def training_step(self, batch, batch_idx): + torch.cuda.empty_cache() + n_generations_per_item = self.cfg.get('n_generations_per_item', 6) + po_outputs = self.process_batch_online_po(batch, n_generations_per_item) + self.log('train_loss', po_outputs['loss'], prog_bar=True, sync_dist=True) + self.log('train_kl_loss', po_outputs['kl_loss'], prog_bar=True, sync_dist=True) + self.log('train_mean_reward', po_outputs['mean_reward'], prog_bar=True, sync_dist=True) + self.log('train_std_reward', po_outputs['std_reward'], prog_bar=True, sync_dist=True) + return po_outputs['loss'] + + def validation_step(self, batch, batch_idx): + po_outputs = self.process_batch_online_po(batch, 1, mode='val') + batch_metrics = po_outputs['batch_metrics'] + mean_reward = po_outputs['mean_reward'] + val_loss = po_outputs['loss'] + val_kl_loss = po_outputs['kl_loss'] + + self.validation_step_outputs.append( + { + 'mean_reward': mean_reward, + 'std_reward': po_outputs['std_reward'], + 'val_loss': val_loss, + 'val_kl_loss': val_kl_loss, + 'batch_metrics': batch_metrics, + } + ) + + def on_validation_epoch_end(self): + def collect(key): + values = [] + for x in self.validation_step_outputs: + if x[key] is not None: + values.append(x[key]) + else: + values.append(torch.tensor(0.0, device=self.device)) + stacked_values = torch.stack(values) + return stacked_values.mean() + + val_loss = collect("val_loss") + val_kl_loss = collect("val_kl_loss") + mean_reward = collect("mean_reward") + std_reward = collect("std_reward") + + self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) + self.log("val_kl_loss", val_kl_loss, prog_bar=True, sync_dist=True) + self.log("val_mean_reward", mean_reward, prog_bar=True, sync_dist=True) + self.log("val_std_reward", std_reward, prog_bar=True, sync_dist=True) + + mean_metrics = {} + for val_output in self.validation_step_outputs: + batch_metrics = val_output['batch_metrics'] + for item_metrics in batch_metrics: + for key, value in item_metrics.items(): + if "transcript" not in key: + if key not in mean_metrics: + mean_metrics[key] = [] + mean_metrics[key].append(value) + + for key, values in mean_metrics.items(): + mean_metrics[key] = np.mean(values) + self.log(f"val_{key}", mean_metrics[key], prog_bar=True, sync_dist=True) + + 
self.validation_step_outputs.clear() + + +# Utility functions +def process_text_for_cer(input_text): + """ + Normalizes text for CER/WER calculation. + Taken from hallucination_eval.py + """ + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + no_dash_text = no_dash_text.replace("'", "") + no_dash_text = no_dash_text.replace(";", "") + no_dash_text = no_dash_text.replace(".", "") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + # @shehzeen: Added this to handle some common errors in ASR transcripts + single_space_text = single_space_text.replace("h t t p", "http") + single_space_text = single_space_text.replace("w w w", "www") + + return single_space_text + + +def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, device): + audio_batch = [] + audio_lengths = [] + for filepath in filepaths: + audio, sr = sf.read(filepath) + if sr != 16000: + audio = librosa.core.resample(audio, orig_sr=sr, target_sr=16000) + audio_tensor = torch.tensor(audio, dtype=torch.float32, device=device) + audio_batch.append(audio_tensor) + audio_lengths.append(audio_tensor.size(0)) + + batch_audio_lens = torch.tensor(audio_lengths, device=device).long() + max_audio_len = int(batch_audio_lens.max().item()) + audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len]) + + _, speaker_embeddings = speaker_verification_model.forward( + input_signal=audio_batch, input_signal_length=batch_audio_lens + ) + + return speaker_embeddings + + +def transcribe_with_whisper( + audio_filepath, language, whisper_processor, whisper_model, device, normalizer: Optional[Normalizer] = None +): + speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000) + forced_decoder_ids = ( + whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None + ) + inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features + inputs = inputs.to(device) + with torch.no_grad(): + predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) + transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) + result = transcription[0] + if normalizer is not None: + result = normalizer.normalize(result) + return result diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py old mode 100644 new mode 100755 index baf9a1648282..d133a02dd4e7 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -31,9 +31,9 @@ from nemo.core.neural_types.elements import ( AudioSignal, EncodedRepresentation, - Index, LengthsType, MelSpectrogramType, + TokenIndex, VoidType, ) from nemo.core.neural_types.neural_type import NeuralType @@ -54,6 +54,19 @@ HAVE_FSSPEC = False +from contextlib import contextmanager + + +@contextmanager +def default_precision(dtype=torch.float32): + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(default_dtype) + + def get_padding(kernel_size: int, dilation: int = 1) -> int: return (kernel_size * dilation - dilation) // 2 @@ -407,40 +420,41 @@ def forward(self, x, 
l2_norm=False): Shapes: - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` """ - x.squeeze_(1) - # if you torch spec compute it otherwise use the mel spec computed by the AP - if self.use_torch_spec: - x = self.torch_spec(x) + with default_precision(torch.float32): + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) - if self.log_input: - x = (x + 1e-6).log() - x = self.instancenorm(x).unsqueeze(1) + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) - x = self.conv1(x) - x = self.relu(x) - x = self.bn1(x) + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) - x = x.reshape(x.size()[0], -1, x.size()[-1]) + x = x.reshape(x.size()[0], -1, x.size()[-1]) - w = self.attention(x) + w = self.attention(x) - if self.encoder_type == "SAP": - x = torch.sum(x * w, dim=2) - elif self.encoder_type == "ASP": - mu = torch.sum(x * w, dim=2) - sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) - x = torch.cat((mu, sg), 1) + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) - x = x.view(x.size()[0], -1) - x = self.fc(x) + x = x.view(x.size()[0], -1) + x = self.fc(x) - if l2_norm: - x = torch.nn.functional.normalize(x, p=2, dim=1) + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) return x def get_torch_mel_spectrogram_class(self, audio_config): @@ -498,6 +512,7 @@ def __init__( kernel_size: int, stride: int = 1, groups: int = None, + activation: Optional[str] = None, trim_right_ratio: int = 1, bias=True, ): @@ -510,6 +525,11 @@ def __init__( self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = nn.Identity() + kernel_size = self.conv.kernel_size[0] stride = self.conv.stride[0] padding_total = kernel_size - stride @@ -538,6 +558,7 @@ def forward(self, inputs, input_len): # unpad end = hidden_states.shape[-1] - self.padding_right hidden_states = hidden_states[..., self.padding_left : end] + hidden_states = self.activation(hidden_states) # mask hidden_states = mask_sequence_tensor(hidden_states, input_len) return hidden_states @@ -554,6 +575,7 @@ def __init__( stride: int = 1, dilation: int = 1, groups: int = 1, + activation: Optional[str] = None, pad_mode: str = "zeros", extra_pad_mode: str = "constant", bias: bool = True, @@ -578,6 +600,10 @@ def __init__( bias=bias, padding_mode=pad_mode, ) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = nn.Identity() kernel_size = self.conv.kernel_size[0] stride = torch.tensor(self.conv.stride[0], dtype=torch.int64) @@ -602,12 +628,12 @@ def _get_extra_padding_for_conv1d( hidden_states: torch.Tensor, ) -> torch.Tensor: """See `pad_for_conv1d`.""" - length = hidden_states.shape[-1] - n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 - n_frames = torch.ceil(n_frames).to(torch.int64) - 1 - ideal_length = n_frames * self.stride + self.kernel_size - 
self.padding_total - - return ideal_length - length + with default_precision(torch.float32): + length = hidden_states.shape[-1] + n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = torch.ceil(n_frames).to(torch.int64) - 1 + ideal_length = (n_frames * self.stride).long() + self.kernel_size - self.padding_total + return (ideal_length - length).long() @staticmethod # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._pad1d @@ -635,6 +661,7 @@ def forward(self, inputs, input_len): # Left padding for causal hidden_states = self._pad1d(inputs, (self.padding_total, extra_padding), mode=self.extra_pad_mode) hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) # mask output hidden_states = mask_sequence_tensor(hidden_states, input_len) @@ -652,6 +679,7 @@ def __init__( dilation: int = 1, padding: Optional[int] = None, pad_mode: str = "reflect", + activation: Optional[str] = None, ): super().__init__() if not padding: @@ -666,6 +694,10 @@ def __init__( padding_mode=pad_mode, ) self.conv = nn.utils.parametrizations.weight_norm(conv) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = torch.nn.Identity() @property def input_types(self): @@ -686,12 +718,21 @@ def remove_weight_norm(self): @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) + out = self.activation(out) out = mask_sequence_tensor(out, input_len) return out class ConvTranspose1dNorm(NeuralModule): - def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, groups: int = 1): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + activation: Optional[str] = None, + ): super().__init__() padding, output_padding = get_up_sample_padding(kernel_size, stride) conv = nn.ConvTranspose1d( @@ -706,6 +747,11 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride ) self.conv = nn.utils.parametrizations.weight_norm(conv) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = nn.Identity() + @property def input_types(self): return { @@ -725,6 +771,7 @@ def remove_weight_norm(self): @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) + out = self.activation(out) out = mask_sequence_tensor(out, input_len) return out @@ -999,7 +1046,8 @@ def output_types(self): def forward(self, audio): scores_list = [] fmap_list = [] - spec = self.compute_stft(audio) + # run spec compute on fp32 and convert out to the model training type + spec = self.compute_stft(audio.float()).to(audio.dtype) for band, disc in zip(self.stft_bands, self.discriminators): spec_band = spec[:, :, :, band[0] : band[1]] score, fmap = disc(spec=spec_band) @@ -1105,6 +1153,16 @@ def forward(self, audio_real, audio_gen): class VectorQuantizerBase(NeuralModule, ABC): + @property + @abstractmethod + def num_codebooks(self) -> int: + pass + + @property + @abstractmethod + def codebook_size(self) -> int: + pass + @property def input_types(self): return { @@ -1116,7 +1174,7 @@ def input_types(self): def output_types(self): return { "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), } @typecheck() @@ -1129,7 +1187,7 @@ def forward(self, 
inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch. "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + output_types={"indices": NeuralType(('D', 'B', 'T'), TokenIndex())}, ) @abstractmethod def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @@ -1137,7 +1195,7 @@ def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @typecheck( input_types={ - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ @@ -1184,6 +1242,11 @@ def __init__(self, num_levels: List[int], eps: float = 1e-3): logging.debug('\tcodebook_size: %s', self.codebook_size) logging.debug('\teps: %s', self.eps) + @property + def num_codebooks(self): + """Returns the number of codebooks.""" + return 1 + @property def codebook_size(self): """Returns the size of the corresponding codebook.""" @@ -1249,7 +1312,7 @@ def compress(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tenso "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"codes": NeuralType(('B', 'D', 'T'), Index())}, + output_types={"codes": NeuralType(('B', 'D', 'T'), TokenIndex())}, ) def inputs_to_codes(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: # apply compression @@ -1311,7 +1374,7 @@ def forward( "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), }, - output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + output_types={"indices": NeuralType(('D', 'B', 'T'), TokenIndex())}, ) def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: """Convert a continuous code vector to a single index.""" @@ -1320,7 +1383,7 @@ def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) @typecheck( input_types={ - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), }, output_types={ @@ -1380,19 +1443,19 @@ def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs): logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group) @property - def codebook_dim(self): - """Input vector dimension.""" - return self.codebook_dim_per_group * self.num_groups + def num_codebooks(self): + """Returns the number of codebooks.""" + return self.num_groups @property - def codebook_size_per_group(self): - """Returns the size of the implicit codebook for each group.""" + def codebook_size(self): + """Returns the size of the codebook for each group.""" return self.fsqs[0].codebook_size @property - def codebook_size(self): - """Returns the size of the implicit codebook.""" - return self.codebook_size_per_group**self.num_groups + def codebook_dim(self): + """Input vector dimension.""" + return self.codebook_dim_per_group * self.num_groups @typecheck() def forward(self, inputs, input_len): @@ -1419,7 +1482,7 @@ def forward(self, inputs, input_len): "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + output_types={"indices": 
NeuralType(('D', 'B', 'T'), TokenIndex())}, ) def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: """Input is split into groups, each group is encoded separately, then the results are concatenated.""" @@ -1437,7 +1500,7 @@ def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @typecheck( input_types={ - "indices": NeuralType(('D', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), TokenIndex()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ @@ -1458,6 +1521,33 @@ def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor return dequantized + @typecheck( + input_types={ + "codes": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "indices": NeuralType(('B', 'D', 'T'), TokenIndex()), + }, + ) + def codes_to_indices(self, codes: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: + """Converts a code vector to indices.""" + codes_rearrange = rearrange(codes, 'B D T -> D B T') + codes_grouped = codes_rearrange.chunk(self.num_groups, dim=0) + indices = [] + + for codes_group, fsq_group in zip(codes_grouped, self.fsqs): + codes_group_rearrange = rearrange(codes_group, 'D B T -> B D T') + # [B, T] + indices_group = fsq_group.codes_to_indices(codes=codes_group_rearrange) + indices_group = mask_sequence_tensor(indices_group, input_len) + indices.append(indices_group) + + # concatenate along the feature dimension + indices = torch.stack(indices, dim=1) + + return indices + class ResidualBlock(NeuralModule): """ @@ -1534,6 +1624,72 @@ def forward(self, inputs, input_len): return out +class ResidualBlockV2(NeuralModule): + """ + Residual block which applies activation to output instead of input. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + activation: Activation to apply in between residual convolutions. + is_causal: Whether to use causal convolutions. + pad_mode: Type of padding to use for conv1d layers. 
+ See https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + """ + + def __init__( + self, + channels: int, + filters: int, + kernel_size: int = 3, + activation: str = "lrelu", + is_causal: bool = False, + pad_mode: str = "reflect", + ): + super(ResidualBlockV2, self).__init__() + + if not is_causal: + self.input_conv = Conv1dNorm( + in_channels=channels, + out_channels=filters, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, + ) + self.skip_conv = Conv1dNorm( + in_channels=filters, out_channels=channels, kernel_size=kernel_size, pad_mode=pad_mode + ) + else: + self.input_conv = CausalConv1dNorm( + in_channels=channels, out_channels=filters, kernel_size=kernel_size, activation=activation + ) + self.skip_conv = CausalConv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) + + self.output_activation = CodecActivation(activation=activation, channels=channels) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} + + @property + def output_types(self): + return {"out": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} + + @typecheck() + def forward(self, inputs, input_len): + res = self.input_conv(inputs=inputs, input_len=input_len) + res = self.skip_conv(inputs=res, input_len=input_len) + out = inputs + res + out = self.output_activation(out) + out = mask_sequence_tensor(out, lengths=input_len) + return out + + class HiFiGANResBlock(NeuralModule): """ Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. @@ -1765,7 +1921,8 @@ def forward(self, audio, audio_len): out = res_layer(inputs=out, input_len=encoded_len) out = act(out) - encoded_len = encoded_len // down_sample_rate + with default_precision(torch.float32): + encoded_len = (encoded_len // down_sample_rate).long() # [B, 2 * C, T / down_sample_rate] out = down_sample_conv(inputs=out, input_len=encoded_len) @@ -1886,7 +2043,8 @@ def forward(self, audio, audio_len): out = res_layer(inputs=out, input_len=encoded_len) out = act(out) - encoded_len = encoded_len // down_sample_rate + with default_precision(torch.float32): + encoded_len = (encoded_len // down_sample_rate).long() # [B, 2 * C, T / down_sample_rate] out = down_sample_conv(inputs=out, input_len=encoded_len) @@ -2012,7 +2170,8 @@ def forward(self, inputs, input_len): for act, res_layer, up_sample_conv, up_sample_rate in zip( self.activations, self.res_layers, self.up_sample_conv_layers, self.up_sample_rates ): - audio_len = audio_len * up_sample_rate + with default_precision(torch.float32): + audio_len = (audio_len * up_sample_rate).long() out = act(out) # [B, C / 2, T * up_sample_rate] out = up_sample_conv(inputs=out, input_len=audio_len) @@ -2207,6 +2366,61 @@ def forward(self, audio, audio_len): return spec, spec_len +class STFTProcessor(NeuralModule): + """ + Interface for computing log magnitude STFT features. + + Args: + n_fft: Size of Fourier transform + win_length: The size of the sliding window frames for windowing and STFT. + hop_length: The distance between neighboring sliding window frames + log_guard: Value to add to magnitude STFT before taking log. 
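For reference, a standalone sketch that mirrors the computation this processor performs, using illustrative settings (n_fft=1024, hop_length=256, roughly one second of 22.05 kHz audio; none of these values come from a shipped config) to make the expected output shapes concrete:

```python
import torch

# Illustrative settings; the real values come from the encoder's STFT resolutions.
n_fft, win_length, hop_length, log_guard = 1024, 1024, 256, 1.0

audio = torch.randn(2, 22050)            # [B, T_audio], roughly 1 s at 22.05 kHz
audio_len = torch.tensor([22050, 16384])

spec_len = audio_len // hop_length       # valid frames per example
pad = (n_fft - hop_length) // 2
audio_padded = torch.nn.functional.pad(audio, (pad, pad), "reflect")
fft = torch.stft(
    audio_padded,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=torch.hann_window(win_length, periodic=False),
    return_complex=True,
    center=False,
)
spec = torch.log(fft.abs() + log_guard)  # [B, n_fft // 2 + 1, T_spec]
print(spec.shape, spec_len)              # expected: torch.Size([2, 513, 86]) tensor([86, 64])
```

The module itself additionally zeroes out frames beyond `spec_len` via `mask_sequence_tensor`.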
+ """ + + def __init__(self, n_fft, win_length, hop_length, log_guard=1.0): + super(STFTProcessor, self).__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.log_guard = log_guard + self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), + "spec_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + spec_len = audio_len // self.hop_length + audio_padded = torch.nn.functional.pad(audio, (self.stft_pad_amount, self.stft_pad_amount), "reflect") + # [B, n_fft, T_spec] + fft = torch.stft( + audio_padded, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + return_complex=True, + center=False, + ) + fft_mag = torch.abs(fft) + fft_mag_log = torch.log(fft_mag + self.log_guard) + fft_mag_log = mask_sequence_tensor(fft_mag_log, spec_len) + return fft_mag_log, spec_len + + class ResNetEncoder(NeuralModule): """ Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing @@ -2388,3 +2602,518 @@ def forward(self, audio, audio_len): # [B, C, T] encoded = torch.cat(outputs, dim=1) return encoded, spec_len + + +class STFTResidualBlock(NeuralModule): + """ + Block in multi-resolution STFT encoder which adds an STFT resolution to the encoder latent space, after down + sampling the input to match the time resoluton of the STFT features. + + Args: + resolution: STFT resolution, formatted as a 3-tuple (n_fft, hop_length, window_size) + input_dim: Dimension if input latenct features. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + activation: Name of activation function. + down_sample_rate: Down sample factor to reduce input by before adding STFT encoding. 
+ """ + + def __init__( + self, + resolution: Tuple[int], + input_dim: int, + filters: int, + kernel_size: int, + activation: str, + down_sample_rate: int, + pad_mode: str, + ): + super(STFTResidualBlock, self).__init__() + down_sample_kernel_size = down_sample_rate * 2 + 1 + + self.down_sample_rate = down_sample_rate + self.down_sample_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=filters, + kernel_size=down_sample_kernel_size, + stride=self.down_sample_rate, + activation=activation, + pad_mode=pad_mode, + ) + + n_fft, hop_length, win_length = resolution + stft_dim = n_fft // 2 + 1 + self.spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) + self.spec_conv = Conv1dNorm(in_channels=stft_dim, out_channels=filters, kernel_size=kernel_size) + self.spec_act = CodecActivation(activation=activation, channels=filters) + + self.res_block = ResidualBlockV2( + channels=filters, filters=filters, kernel_size=kernel_size, activation=activation, pad_mode=pad_mode + ) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "out_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, inputs, input_len, audio, audio_len): + out_len = input_len // self.down_sample_rate + out = self.down_sample_conv(inputs=inputs, input_len=out_len) + + spec, _ = self.spec_processor(audio=audio, audio_len=audio_len) + spec_res = self.spec_conv(inputs=spec, input_len=out_len) + out = out + spec_res + out = self.spec_act(out) + + out = self.res_block(inputs=out, input_len=out_len) + return out, out_len + + +class DownSampleResidualBlock(NeuralModule): + """ + Layer which combines a down sampling layer with a residual block. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + activation: Activation to apply in between residual convolutions. + down_sample_rate: Factor to down sample time dimension by. 
+ """ + + def __init__( + self, + channels: int, + filters: int, + kernel_size: int, + activation: str, + down_sample_rate: int, + pad_mode: str, + ): + super(DownSampleResidualBlock, self).__init__() + down_sample_kernel_size = down_sample_rate * 2 + 1 + + self.down_sample_rate = down_sample_rate + self.down_sample_conv = Conv1dNorm( + in_channels=channels, + out_channels=filters, + kernel_size=down_sample_kernel_size, + stride=self.down_sample_rate, + activation=activation, + pad_mode=pad_mode, + ) + self.res_block = ResidualBlockV2( + channels=filters, filters=filters, kernel_size=kernel_size, activation=activation + ) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "out_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, inputs, input_len): + output_len = input_len // self.down_sample_rate + out = self.down_sample_conv(inputs=inputs, input_len=output_len) + out = self.res_block(inputs=out, input_len=output_len) + return out, output_len + + +class MultiResolutionSTFTEncoder(NeuralModule): + """ + Encoder which computes log magnitude STFT features at several time resolutions and encodes them into a low + frame-rate representation. + + Args: + out_dim: Dimension of encoder output embedding. + resolutions: List of STFT resolutions, formatted as 3-tuples (n_fft, hop_length, window_size) + resolution_filter_list: List the same size as 'resolutions', specifying the number of filters in the residual + block for each STFT resolution. + down_sample_filter_list: List of filters to use for each down sampling block after initial STFT encoding. + down_sample_rate_list: List of rates to use for each down sampling block after initial STFT encoding. + The total down sample rate of the encoder will be 2**(len(resolutions)) * product(down_sample_rate_list) + kernel_size: Kernel size to use in all convolutions. + activation: Name of activation function. + resample_rates: Optional tuple of two integers. If provided, input audio will be resampled from sampling rate + resample_rates[0] to sampling rate resample_rates[1]. + pad_mode: Type of padding to use for conv1d layers. 
+ See https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + """ + + def __init__( + self, + out_dim: int, + resolutions: List[Tuple[int]], + resolution_filter_list: List[int], + down_sample_filter_list: Tuple[int] = (), + down_sample_rate_list: Tuple[int] = (), + kernel_size: int = 3, + activation: str = "lrelu", + resample_rates: Tuple[int] = (), + pad_mode: str = "replicate", + ): + super(MultiResolutionSTFTEncoder, self).__init__() + assert len(resolutions) >= 1 + assert len(resolutions) == len(resolution_filter_list) + + if resample_rates: + if not HAVE_TORCHAUDIO: + raise ValueError("Must install torchaudio for resampling.") + + input_sr, encoder_sr = resample_rates + self.resample = torchaudio.transforms.Resample(input_sr, encoder_sr) + self.resample_length_modifier = encoder_sr / input_sr + else: + self.resample = torch.nn.Identity() + self.resample_length_modifier = 1.0 + + n_fft, hop_length, win_length = resolutions[0] + input_filters = resolution_filter_list[0] + input_dim = n_fft // 2 + 1 + self.pre_spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) + self.pre_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=input_filters, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, + ) + self.pre_res_block = ResidualBlockV2( + channels=input_filters, + filters=input_filters, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, + ) + input_dim = input_filters + self.stft_blocks = nn.ModuleList([]) + for resolution, filters in zip(resolutions[1:], resolution_filter_list[1:]): + stft_block = STFTResidualBlock( + resolution=resolution, + input_dim=input_dim, + down_sample_rate=2, + filters=filters, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, + ) + self.stft_blocks.append(stft_block) + input_dim = filters + + if down_sample_filter_list and not down_sample_rate_list: + down_sample_rate_list = len(down_sample_filter_list) * [2] + + self.down_sample_blocks = nn.ModuleList([]) + for filters, down_sample_rate in zip(down_sample_filter_list, down_sample_rate_list): + down_sample_block = DownSampleResidualBlock( + channels=input_dim, + filters=filters, + down_sample_rate=down_sample_rate, + kernel_size=kernel_size, + activation=activation, + pad_mode=pad_mode, + ) + self.down_sample_blocks.append(down_sample_block) + input_dim = filters + + self.post_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=out_dim, + kernel_size=kernel_size, + pad_mode=pad_mode, + ) + + def remove_weight_norm(self): + self.encoder.remove_weight_norm() + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + audio = self.resample(audio) + audio_len = torch.round(self.resample_length_modifier * audio_len).int() + + encoded, encoded_len = self.pre_spec_processor(audio=audio, audio_len=audio_len) + encoded = self.pre_conv(inputs=encoded, input_len=encoded_len) + encoded = self.pre_res_block(inputs=encoded, input_len=encoded_len) + + for stft_block in self.stft_blocks: + encoded, encoded_len = stft_block(inputs=encoded, input_len=encoded_len, audio=audio, audio_len=audio_len) + + for down_sample_block in self.down_sample_blocks: + encoded, 
encoded_len = down_sample_block(inputs=encoded, input_len=encoded_len) + + encoded = self.post_conv(inputs=encoded, input_len=encoded_len) + + return encoded, encoded_len + + +class VectorQuantizerIndexConverter(NeuralModule): + """ + Utility for converting indices between two FSQ definitions. + + Example: + + from nemo.collections.tts.models import AudioCodecModel + from nemo.collections.tts.modules.audio_codec_modules import GroupFiniteScalarQuantizer, VectorQuantizerIndexConverter + + audio_file = "/home/audio.wav" + codec_path = "/home/SpectralCodecFps43.nemo" + + device = "cuda:0" + + audio, _ = librosa.load(audio_file, sr=sample_rate) + + audio_tensor = torch.tensor([audio]).to(device) + audio_len_tensor = torch.tensor([audio.shape[0]]).to(device) + + codec_model = AudioCodecModel.restore_from(codec_path, map_location=device) + tokens, token_len = codec_model.encode(audio=audio_tensor, audio_len=audio_len_tensor) + + fsq_new = GroupFiniteScalarQuantizer(num_groups=6, num_levels_per_group=[5, 5, 5, 5]).to(device) + + # vector_quantizer_original has 4 codebooks with 6 levels [5, 5, 5, 5, 5, 5] + # vector_quantizer_new has 6 codebooks with 4 levels [5, 5, 5, 5] + fsq_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=fsq_new + ) + + tokens_new = fsq_converter.convert_original_to_new(audio_tokens=tokens, audio_lens=token_len) + tokens_original = fsq_converter.convert_new_to_original(audio_tokens=tokens_new, audio_lens=token_len) + + """ + + def __init__(self, vector_quantizer_original, vector_quantizer_new): + super().__init__() + self.vector_quantizer_original = vector_quantizer_original + self.vector_quantizer_new = vector_quantizer_new + + # Input [batch, num_codebooks_original, time] + # Output [batch, num_codebooks_new, time] + def convert_original_to_new(self, audio_tokens, audio_lens): + audio_tokens_rearrange = rearrange(audio_tokens, 'B C T -> C B T') + audio_codes = self.vector_quantizer_original.decode(indices=audio_tokens_rearrange, input_len=audio_lens) + audio_tokens_new = self.vector_quantizer_new.codes_to_indices(codes=audio_codes, input_len=audio_lens) + return audio_tokens_new + + # Input [batch, num_codebooks_new, time] + # Output [batch, num_codebooks_original, time] + def convert_new_to_original(self, audio_tokens, audio_lens): + audio_tokens_rearrange = rearrange(audio_tokens, 'B C T -> C B T') + audio_codes = self.vector_quantizer_new.decode(indices=audio_tokens_rearrange, input_len=audio_lens) + audio_tokens_original = self.vector_quantizer_original.codes_to_indices( + codes=audio_codes, input_len=audio_lens + ) + return audio_tokens_original + + +class ResNetDecoder(NeuralModule): + """ + A residual decoder designed for low-latency. Most processing is done at a low frame-rate (e.g. 50 FPS), while + minimizing the size of the network which upsamples to the final waveform. + + Args: + input_dim: Dimension of decoder input. + input_filters: Size of the first CNN layer applied to the decoder input. + pre_up_sample_rates: Up sample rates to apply prior to main decoder network. + pre_up_sample_filters: Size of residual blocks in first up sampling blocks. + n_hidden_layers: Number of residual blocks in the main decoder network, which processes the latent space at + low frame-rate. + hidden_filters: Size of each rsidual block in the main decoder network. + resblock_up_sample_rates: Up sample rates to apply after main decoder network. 
+ resblock_up_sample_filters: Size of residual blocks in final up sampling blocks. + resblock_up_sample_kernel_size: Kernel size to use in final up sampling blocks. + kernel_size: Kernel size to use in all other CNN layers. + activation: Name of activation to use in residual blocks. + is_causal: Whether to make the decoder causal. + pad_mode: Type of padding to use for conv1d layers. + See https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + """ + + def __init__( + self, + input_dim: int, + input_filters: int, + pre_up_sample_rates: List[int], + pre_up_sample_filters: List[int], + n_hidden_layers: int, + hidden_filters: int, + resblock_up_sample_rates: List[int], + resblock_up_sample_filters: List[int], + resblock_up_sample_kernel_size: int = 7, + kernel_size: int = 3, + activation: str = "half_snake", + is_causal: bool = False, + pad_mode: str = "replicate", + ): + super().__init__() + + assert len(pre_up_sample_rates) == len(pre_up_sample_filters) + assert len(resblock_up_sample_rates) == len(resblock_up_sample_filters) + + if not is_causal: + conv_class = Conv1dNorm + else: + conv_class = CausalConv1dNorm + + if not is_causal: + conv_transpose_class = ConvTranspose1dNorm + else: + conv_transpose_class = CausalConvTranspose1dNorm + + self.pre_conv = conv_class( + in_channels=input_dim, + out_channels=input_filters, + kernel_size=kernel_size, + ) + + in_channels = input_filters + self.pre_up_sample_rates = pre_up_sample_rates + self.pre_resblocks = nn.ModuleList([]) + self.pre_up_sample_layers = nn.ModuleList([]) + for up_sample_rate, filters in zip(self.pre_up_sample_rates, pre_up_sample_filters): + res_block = ResidualBlockV2( + channels=in_channels, + filters=(2 * in_channels), + kernel_size=kernel_size, + activation=activation, + is_causal=is_causal, + pad_mode=pad_mode, + ) + self.pre_resblocks.append(res_block) + conv = conv_transpose_class( + in_channels=in_channels, + out_channels=filters, + kernel_size=(2 * up_sample_rate), + stride=up_sample_rate, + activation=activation, + ) + self.pre_up_sample_layers.append(conv) + + in_channels = filters + + self.conv_layers = nn.ModuleList( + [ + ResidualBlockV2( + channels=in_channels, + filters=hidden_filters, + kernel_size=kernel_size, + activation=activation, + is_causal=is_causal, + pad_mode=pad_mode, + ) + for _ in range(n_hidden_layers) + ] + ) + + self.resblock_up_sample_rates = resblock_up_sample_rates + self.resblock_up_sample_layers = nn.ModuleList([]) + self.resblocks = nn.ModuleList([]) + for up_sample_rate, filters in zip(self.resblock_up_sample_rates, resblock_up_sample_filters): + conv = conv_transpose_class( + in_channels=in_channels, + out_channels=filters, + kernel_size=(2 * up_sample_rate), + stride=up_sample_rate, + activation=activation, + ) + self.resblock_up_sample_layers.append(conv) + res_block = ResidualBlockV2( + channels=filters, + filters=(2 * filters), + kernel_size=resblock_up_sample_kernel_size, + activation=activation, + is_causal=is_causal, + pad_mode=pad_mode, + ) + self.resblocks.append(res_block) + in_channels = filters + + self.post_conv = conv_class( + in_channels=in_channels, out_channels=1, kernel_size=resblock_up_sample_kernel_size, pad_mode=pad_mode + ) + + self.out_activation = ClampActivation(clamp_training=False) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T_encoded'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), 
AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, inputs, input_len): + + out = self.pre_conv(inputs=inputs, input_len=input_len) + + audio_len = input_len + for pre_up_sample_rate, pre_up_sample_layer, pre_resblock in zip( + self.pre_up_sample_rates, self.pre_up_sample_layers, self.pre_resblocks + ): + out = pre_resblock(inputs=out, input_len=audio_len) + audio_len = pre_up_sample_rate * audio_len + out = pre_up_sample_layer(inputs=out, input_len=audio_len) + + for conv in self.conv_layers: + out = conv(inputs=out, input_len=audio_len) + + for resblock_up_sample_rate, resblock_up_sample_layer, resblock in zip( + self.resblock_up_sample_rates, self.resblock_up_sample_layers, self.resblocks + ): + audio_len = resblock_up_sample_rate * audio_len + out = resblock_up_sample_layer(inputs=out, input_len=audio_len) + out = resblock(inputs=out, input_len=audio_len) + + out = self.post_conv(inputs=out, input_len=audio_len) + out = rearrange(out, 'B 1 T -> B T') + audio = self.out_activation(out) + audio = mask_sequence_tensor(audio, audio_len) + + return audio, audio_len diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index e9a1556ab700..db32566a536e 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -59,7 +59,6 @@ from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, LossType, VoidType from nemo.core.neural_types.neural_type import NeuralType from nemo.utils import logging -from nemo.utils.decorators import experimental class SEANetResnetBlock(NeuralModule): @@ -537,7 +536,6 @@ def _mask_3d(tensor: Tensor, lengths: Tensor): return tensor * mask -@experimental class EuclideanCodebook(NeuralModule): """ Codebook with Euclidean distance. @@ -739,6 +737,16 @@ def __init__( ] ) + @property + def num_codebooks(self): + """Returns the number of codebooks.""" + return len(self.codebooks) + + @property + def codebook_size(self): + """Returns the size of the codebook for each group.""" + return self.codebooks[0].codebook_size + # Override output types, since this quantizer returns commit_loss @property def output_types(self): @@ -837,7 +845,7 @@ class GroupResidualVectorQuantizer(VectorQuantizerBase): def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwargs): super().__init__() - self.num_codebooks = num_codebooks + self._num_codebooks = num_codebooks self.num_groups = num_groups self.codebook_dim = codebook_dim @@ -858,6 +866,16 @@ def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwa logging.debug('\tnum_codebooks_per_group: %d', self.num_codebooks_per_group) logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group) + @property + def num_codebooks(self): + """Returns the number of codebooks.""" + return self._num_codebooks + + @property + def codebook_size(self): + """Returns the size of the codebook for each group.""" + return self.rvqs[0].codebook_size + @property def num_codebooks_per_group(self): """Number of codebooks for each group.""" diff --git a/nemo/collections/tts/modules/fcd_metric.py b/nemo/collections/tts/modules/fcd_metric.py new file mode 100644 index 000000000000..42751897c906 --- /dev/null +++ b/nemo/collections/tts/modules/fcd_metric.py @@ -0,0 +1,279 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This is an experimental metric. It measures the Frechet Distance between distributions of generated and real +codec frames. The distance is measured in the embedding space of the codec. We get the embeddings +by dequantizing codec frames. + +Like all FD metrics, the metric operates on a dataset level. A large number of real and generated frames are needed for the metric to be reliable -- on the order of tens of thousands. + +The frames are currently considered independently, i.e. temporal relationships between are not captured (though this might +be useful to explore). +""" + +import numpy as np +import torch +from torch import Tensor, nn +from torchmetrics import Metric + +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.tts.models import AudioCodecModel +from nemo.utils import logging + + +class CodecEmbedder(nn.Module): + """ + Embeds audio codec codes into the codec's continuous embedding space. + Accepts as input either a batch of codes or a path to an audio file. + """ + + def __init__(self, codec: AudioCodecModel): + super().__init__() + self.codec = codec + + def codes_to_embedding(self, x: Tensor, x_len: Tensor) -> Tensor: + """ + Embeds a batch of audio codec codes into the codec's continuous embedding space. + """ + # x: (B, C, T) + # x_len: (B,) + return self.codec.dequantize(tokens=x, tokens_len=x_len) + + def encode_from_file(self, audio_path: str) -> Tensor: + """ + Encodes an audio file into audio codec codes. + """ + audio_segment = AudioSegment.from_file(audio_path, target_sr=self.codec.sample_rate) + assert np.issubdtype(audio_segment.samples.dtype, np.floating) + audio_min = audio_segment.samples.min() + audio_max = audio_segment.samples.max() + eps = 0.01 # certain ways of normalizing audio can result in samples that are slightly outside of [-1, 1] + if audio_min < (-1.0 - eps) or audio_max > (1.0 + eps): + logging.warning(f"Audio samples are not normalized: min={audio_min}, max={audio_max}") + samples = torch.tensor(audio_segment.samples, device=self.codec.device).unsqueeze(0) + audio_len = torch.tensor(samples.shape[1], device=self.codec.device).unsqueeze(0) + codes, codes_len = self.codec.encode(audio=samples, audio_len=audio_len) + return codes, codes_len + + +class FrechetCodecDistance(Metric): + """ + Computes the Frechet Codec Distance between two distributions of audio codec frames (real and generated). + This is done in codec embedding space, one frame at a time. We name this metric the Frechet Codec Distance (FCD). + """ + + """ + Parts of this are based on the following implementation of FID (Frechet Inception Distance) on images: + + https://github.com/pytorch/torcheval/blob/main/torcheval/metrics/image/fid.py + + # Copyright (c) Meta Platforms, Inc. and affiliates. + # All rights reserved. + # + # This source code is licensed under the BSD-style license found in the + # LICENSE file in the root directory of this source tree. 
+ + Contents of original LICENSE file: + + # BSD License + # + # For torcheval software + # + # Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + # + # Redistribution and use in source and binary forms, with or without modification, + # are permitted provided that the following conditions are met: + # + # * Redistributions of source code must retain the above copyright notice, this + # list of conditions and the following disclaimer. + # + # * Redistributions in binary form must reproduce the above copyright notice, + # this list of conditions and the following disclaimer in the documentation + # and/or other materials provided with the distribution. + # + # * Neither the name Meta nor the names of its contributors may be used to + # endorse or promote products derived from this software without specific + # prior written permission. + # + # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR + # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON + # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + """ + is_differentiable = False + higher_is_better = False + full_state_update = False + + def __init__( + self, + codec, + feature_dim: int, + ) -> None: + """ + Computes the Frechet Codec Distance between two distributions of audio codec codes (real and generated). + The original paper (FID on images): https://arxiv.org/pdf/1706.08500.pdf + + Args: + codec (AudioCodecModel): The codec model to use. + feature_dim (int): The number of features in the codec embedding space (usually 4*num_codebooks) + """ + super().__init__() + + # Set the model and put it in evaluation mode + self.model = CodecEmbedder(codec) + self.model.eval() + self.model.requires_grad_(False) + + # Initialize state variables used to compute FCD + self.add_state("real_sum", default=torch.zeros(feature_dim), dist_reduce_fx="sum") + self.add_state("real_cov_sum", default=torch.zeros((feature_dim, feature_dim)), dist_reduce_fx="sum") + self.add_state("fake_sum", default=torch.zeros(feature_dim), dist_reduce_fx="sum") + self.add_state("fake_cov_sum", default=torch.zeros((feature_dim, feature_dim)), dist_reduce_fx="sum") + self.add_state("num_real_frames", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") + self.add_state("num_fake_frames", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") + + def update_from_audio_file(self, audio_path: str, is_real: bool) -> Tensor: + """ + Takes a path to an audio file, embeds it, and updates the FCD metric. + """ + codes, codes_len = self.model.encode_from_file(audio_path) + self.update(codes, codes_len, is_real) + + def update(self, codes: Tensor, codes_len: Tensor, is_real: bool): + """ + Update the states with a batch of real or fake codes. + Takes pre-computed codec codes, embeds them, and updates the FCD metric. + + Args: + codes (Tensor): A batch of codec frames of shape (B, C, T). 
+ codes_len (Tensor): A batch of lengths of the codec frames of shape (B,). + is_real (Boolean): Denotes if samples are real or not. + """ + assert codes.ndim == 3 + + if codes.numel() == 0: + logging.warning("FCD metric received an empty batch of codes - skipping update") + return + + if codes.shape[1] != self.model.codec.num_codebooks: + logging.warning( + f"FCD metric received a batch of codes of shape {codes.shape}, but the model has {self.model.codec.num_codebooks} codebooks - skipping update" + ) + return + + # Dequantize the codes to a continuous representation + embeddings = self.model.codes_to_embedding( + codes, codes_len + ) # B, E, T where E is the codec's embedding dimension, usually 4*num_codebooks + + # keep only the valid frames + valid_frames = [] + for i in range(codes.shape[0]): + valid_frames.append(embeddings[i, :, : codes_len[i]].T) # T', E + embeddings = torch.cat(valid_frames, dim=0) # total_valid_frames, E + valid_frame_count = embeddings.shape[0] + + # Update the state variables used to compute FCD + if is_real: + self.num_real_frames += valid_frame_count + self.real_sum += torch.sum(embeddings, dim=0) + self.real_cov_sum += torch.matmul(embeddings.T, embeddings) + else: + self.num_fake_frames += valid_frame_count + self.fake_sum += torch.sum(embeddings, dim=0) + self.fake_cov_sum += torch.matmul(embeddings.T, embeddings) + + return self + + def compute(self) -> Tensor: + """ + Compute the FCD. + + Returns: + tensor: The FCD. + """ + + # If the user has not already updated with at lease one + # sample from each distribution, then we raise an Error. + if (self.num_real_frames == 0) or (self.num_fake_frames == 0): + logging.warning( + "Computing FD requires at least 1 real frame and 1 fake frame," + f"but currently running with {self.num_real_frames} real frames and {self.num_fake_frames} fake frames." + "Returning 0.0" + ) + return torch.tensor(0.0, device=self.device) + + # Compute the mean activations for each distribution + real_mean = (self.real_sum / self.num_real_frames).unsqueeze(0) + fake_mean = (self.fake_sum / self.num_fake_frames).unsqueeze(0) + + # Compute the covariance matrices for each distribution + real_cov_num = self.real_cov_sum - self.num_real_frames * torch.matmul(real_mean.T, real_mean) + real_cov = real_cov_num / (self.num_real_frames - 1) + fake_cov_num = self.fake_cov_sum - self.num_fake_frames * torch.matmul(fake_mean.T, fake_mean) + fake_cov = fake_cov_num / (self.num_fake_frames - 1) + + # Compute the Frechet Distance between the distributions + fd = self.calculate_frechet_distance(real_mean.squeeze(), real_cov, fake_mean.squeeze(), fake_cov) + # FD should be non-negative but due to numerical errors, it can be slightly negative + # Have seen -0.0011 in the past + if fd < -0.005: + logging.warning(f"FCD is negative, which is unexpected: {fd}") + return torch.clamp(fd, min=0.0) + + def calculate_frechet_distance( + self, + mu1: Tensor, + sigma1: Tensor, + mu2: Tensor, + sigma2: Tensor, + ) -> Tensor: + """ + Calculate the Frechet Distance between two multivariate Gaussian distributions. + + Args: + mu1 (Tensor): The mean of the first distribution. Shape: (feature_dim,) + sigma1 (Tensor): The covariance matrix of the first distribution. Shape: (feature_dim, feature_dim) + mu2 (Tensor): The mean of the second distribution. Shape: (feature_dim,) + sigma2 (Tensor): The covariance matrix of the second distribution. Shape: (feature_dim, feature_dim) + + Returns: + tensor: The Frechet Distance between the two distributions. 
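For reference, the closed-form quantity evaluated below is the squared Frechet (2-Wasserstein) distance between two Gaussians, with the trace of the matrix square root obtained from the eigenvalues of the product of the covariances:

$$
d^2\big(\mathcal{N}(\mu_1,\Sigma_1),\,\mathcal{N}(\mu_2,\Sigma_2)\big)
= \lVert \mu_1 - \mu_2 \rVert_2^2
+ \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2)
- 2\sum_i \sqrt{\lambda_i(\Sigma_1 \Sigma_2)}
$$

A minimal usage sketch of the metric as a whole (the checkpoint path and the wav lists are placeholders; `feature_dim` follows the `4 * num_codebooks` rule of thumb from the class docstring):

```python
from nemo.collections.tts.models import AudioCodecModel
from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance

codec = AudioCodecModel.restore_from("/path/to/codec.nemo", map_location="cuda:0")  # placeholder path
fcd = FrechetCodecDistance(codec=codec, feature_dim=4 * codec.num_codebooks)
fcd = fcd.to(codec.device)  # keep the metric state on the same device as the codec

real_paths = ["real_0.wav"]       # placeholder lists; in practice use many thousands of frames
generated_paths = ["gen_0.wav"]

for path in real_paths:
    fcd.update_from_audio_file(path, is_real=True)
for path in generated_paths:
    fcd.update_from_audio_file(path, is_real=False)

print(fcd.compute())
```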
+ """ + # Compute the squared distance between the means + mean_diff = mu1 - mu2 + mean_diff_squared = mean_diff.square().sum(dim=-1) + + # Calculate the sum of the traces of both covariance matrices + trace_sum = sigma1.trace() + sigma2.trace() + + # Compute the eigenvalues of the matrix product of the real and fake covariance matrices + sigma_mm = torch.matmul(sigma1, sigma2) + eigenvals = torch.linalg.eigvals(sigma_mm) + + # Take the square root of each eigenvalue and take its sum + sqrt_eigenvals_sum = eigenvals.sqrt().real.sum(dim=-1) + + # Calculate the FCD using the squared distance between the means, + # the sum of the traces of the covariance matrices, and the sum of the square roots of the eigenvalues + fcd = mean_diff_squared + trace_sum - 2 * sqrt_eigenvals_sum + + return fcd diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py new file mode 100644 index 000000000000..8569b691242f --- /dev/null +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -0,0 +1,253 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from enum import Enum + +import torch +from torch import Tensor + +from nemo.collections.tts.modules import transformer_2501 +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.core.classes.module import NeuralModule +from nemo.utils.enum import PrettyStrEnum + + +class LocalTransformerType(PrettyStrEnum): + """ + Enum for the type of local transformer to use in the MagpieTTS model. + These strings are the values allowed in the YAML config file. + """ + + NO_LT = "none" + AR = "autoregressive" + MASKGIT = "maskgit" + + +class EOSDetectionMethod(PrettyStrEnum): + """ + Enum for the EOS detection method to use in the MagpieTTS model. + These strings are the values allowed in the YAML config file. 
+ """ + + ARGMAX_ANY = "argmax_any" + ARGMAX_OR_MULTINOMIAL_ANY = "argmax_or_multinomial_any" + ARGMAX_ALL = "argmax_all" + ARGMAX_OR_MULTINOMIAL_ALL = "argmax_or_multinomial_all" + ARGMAX_ZERO_CB = "argmax_zero_cb" + ARGMAX_OR_MULTINOMIAL_ZERO_CB = "argmax_or_multinomial_zero_cb" + + @staticmethod + def detection_type(detection_method: EOSDetectionMethod): + if detection_method in [EOSDetectionMethod.ARGMAX_ANY, EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ANY]: + return "any" + elif detection_method in [EOSDetectionMethod.ARGMAX_ALL, EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ALL]: + return "all" + elif detection_method in [EOSDetectionMethod.ARGMAX_ZERO_CB, EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ZERO_CB]: + return "zero_cb" + else: + raise ValueError(f"Invalid EOS detection method: {detection_method}") + + @staticmethod + def sampling_type(detection_method: EOSDetectionMethod): + if detection_method in [ + EOSDetectionMethod.ARGMAX_ANY, + EOSDetectionMethod.ARGMAX_ALL, + EOSDetectionMethod.ARGMAX_ZERO_CB, + ]: + return "argmax" + elif detection_method in [ + EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ANY, + EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ALL, + EOSDetectionMethod.ARGMAX_OR_MULTINOMIAL_ZERO_CB, + ]: + return "argmax_or_multinomial" + else: + raise ValueError(f"Invalid EOS detection method: {detection_method}") + + +class SpecialAudioToken(Enum): + """ + Enum for the special tokens to use in the MagpieTTS model. + The special tokens are appended at the end of the codebook after the actual audio codec tokens. + The actual embedding table index is the value below plus the number of codec tokens - do not use the Enum directly. + """ + + AUDIO_BOS = 0 + AUDIO_EOS = 1 + AUDIO_CONTEXT_BOS = 2 + AUDIO_CONTEXT_EOS = 3 + MASK_TOKEN = 4 + # Reserve these values so that if we need to add more special tokens in the future the codebook size will remain the same + RESERVED_1 = 5 + RESERVED_2 = 6 + RESERVED_3 = 7 + + @staticmethod + def get_index(token: SpecialAudioToken, base_codebook_size: int): + """ + Returns the index of the special token in the embedding table. + """ + return base_codebook_size + token.value + + @staticmethod + def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = False) -> list[int]: + """ + Returns a list of token indices that should not be sampled or returned to user. + Args: + base_codebook_size (int): The size of the codec codebook (which is the first part of the embedding table). + forbid_audio_eos (bool): Whether AUDIO_EOS should be forbidden. Default: False (i.e. allowed). + """ + all_special_tokens = list(SpecialAudioToken) + if not forbid_audio_eos: + all_special_tokens.remove(SpecialAudioToken.AUDIO_EOS) + return [SpecialAudioToken.get_index(token, base_codebook_size) for token in all_special_tokens] + + +def cosine_schedule(x: torch.Tensor): + """ + Maps input values from [0, 1] to [1, 0] using the first quadrant of the cosine function. + Used for MaskGit mask scheduling. + """ + return torch.cos(x * (torch.pi / 2)) + + +def build_vocabs(subword_vocab: dict, subword_padding_idx: int, special_vocab: dict = None) -> tuple[dict, dict]: + """ + Builds the character vocabulary and the mapping from subword ids to character ids. + Args: + subword_vocab (dict): A dictionary of subword vocab items. Eg. + tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name) + subword_vocab = tokenizer.vocab + subword_padding_idx (int): The padding index for the subword vocabulary. 
+ special_vocab (dict): items of special token dictionary (usually BOS, EOS) + eg. special_vocab = {'': 0, '': 1} + Returns: + subword_id_to_char_ids: A dictionary mapping subword ids to character ids. + char_vocab: A dictionary mapping character ids to their corresponding characters. + """ + org_char_vocab = {subword: subword_id for subword, subword_id in subword_vocab.items() if len(subword) == 1} + + # Add special tokens directly to char vocab + if special_vocab is not None: + for special_token, special_token_id in special_vocab.items(): + if special_token in org_char_vocab: + raise ValueError(f"Special token {special_token} already exists in the character vocabulary.") + org_char_vocab[special_token] = special_token_id + + sorted_char_vocab = dict(sorted(org_char_vocab.items(), key=lambda x: x[1])) + char_vocab = {k: i for i, (k, _) in enumerate(sorted_char_vocab.items())} + assert sorted(char_vocab.values()) == list(range(len(char_vocab))) + subword_id_to_char_ids = { + subword_id: tuple(char_vocab[char] for char in subword) for subword, subword_id in subword_vocab.items() + } + + # Creating mapping from subword ids of special tokens to their char ids + if special_vocab is not None: + for special_token, special_token_id in special_vocab.items(): + if special_token in subword_id_to_char_ids: + raise ValueError(f"Special token {special_token} already exists in the subword id Vocabulary.") + subword_id_to_char_ids[special_token_id] = (char_vocab[special_token],) + + assert max(subword_id_to_char_ids) == len(subword_id_to_char_ids) - 1 + + # Always add padding token to the end of the vocab (this is the convention used in the original code) + subword_id_to_char_ids[subword_padding_idx] = (len(char_vocab),) + + return subword_id_to_char_ids, char_vocab + + +class CharAwareSubwordEncoder(NeuralModule): + """ + Char-aware subword encoder for the MagpieTTS model. + This module takes subword ids as input, maps them to character ids, and then applies a transformer encoder to the character embeddings. + The output is a tensor of shape (batch_size, max_subword_length, d_embed). + """ + + def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: int, special_vocab: dict = None): + """ + Args: + d_embed (int): The dimension of the embedding. + llm_tokenizer_vocab (dict): A dictionary of subword vocab items. Eg. + tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name) + llm_tokenizer_vocab = tokenizer.vocab + subword_padding_idx (int): The padding index for the subword vocabulary. + special_vocab (dict): items of special token dictionary (usually BOS, EOS) + eg. 
special_vocab = {'': 30001, '': 30002} + """ + super().__init__() + self.subword_id_to_char_ids, self.char_vocab = build_vocabs( + llm_tokenizer_vocab, subword_padding_idx, special_vocab + ) + self.embed_tokens = torch.nn.Embedding(self.vocab_size + 1, d_embed, padding_idx=self.vocab_size) + self.encoder = transformer_2501.Transformer( + n_layers=1, + d_model=d_embed, + d_ffn=d_embed * 4, + sa_n_heads=8, + kernel_size=1, + max_length_causal_mask=256, + use_learnable_pos_emb=True, + ) + + @property + def vocab_size(self): + return len(self.char_vocab) + + def prepare_inputs(self, subword_ids: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]: + device = subword_ids.device + + subword_id_list = torch.masked_select(subword_ids, padding_mask).cpu().tolist() + char_id_list = [list(self.subword_id_to_char_ids[x]) for x in subword_id_list] + + char_lengths = torch.tensor([len(x) for x in char_id_list], dtype=torch.long, device=device) + batch_size = char_lengths.size(0) + + char_ids = torch.full((batch_size, int(char_lengths.max().item())), self.vocab_size, dtype=torch.long) + for i in range(batch_size): + char_ids[i, : char_lengths[i]] = torch.tensor(char_id_list[i]) + char_ids = char_ids.to(device=device) + return char_ids, char_lengths + + def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Tensor: + """ + Args: + subword_ids (Tensor): A tensor of shape (batch_size, max_subword_length) containing the subword ids. + subword_mask (Tensor | None): A tensor of shape (batch_size, max_subword_length) containing the mask for the subword ids. + If None, a mask of ones will be used. + Returns: + Tensor: A tensor of shape (batch_size, max_subword_length, d_embed) containing the subword embeddings. + """ + device = subword_ids.device + if subword_mask is None: + subword_mask = torch.ones_like(subword_ids).bool() + else: + subword_mask = subword_mask.bool() + + if subword_mask.ndim == 3: + subword_mask = subword_mask.squeeze(-1) + + char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) + char_mask = get_mask_from_lengths(char_lengths) + char_emb = self.embed_tokens(char_ids) + # char emb has the shape [B*T, N, channels], where N is the max number of chars tokens decoded from bpe tokens + x = self.encoder(x=char_emb, x_mask=char_mask)['output'] + + # Get average embedding over the chars + mean_emb = ((x / char_mask.unsqueeze(-1).sum(1, keepdim=True)) * char_mask.unsqueeze(-1)).sum(1) + subword_emb = torch.zeros((subword_mask.size(0), subword_mask.size(1), mean_emb.size(-1)), device=device) + subword_emb[subword_mask.unsqueeze(-1).expand(-1, -1, mean_emb.size(-1))] = mean_emb.view(-1) + + return subword_emb diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index dc5debc04f39..ea5a08a0a696 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -82,11 +82,15 @@ def __init__( bias=bias, ) - def forward(self, signal): + def forward(self, signal, signal_mask): + # signal: (B, C, T) + # signal_mask: (B, T) + signal = signal * signal_mask.unsqueeze(1) if self.is_causal: # TODO: maybe replace with identify rather than keep conditional if in forward signal = F.pad(signal, self.causal_padding) conv_signal = self.conv(signal) + conv_signal = conv_signal * signal_mask.unsqueeze(1) return conv_signal @@ -126,12 +130,13 @@ def __init__( self.o_net = ConvolutionLayer(d_ffn, d_model, bias=bias, kernel_size=kernel_size, is_causal=is_causal) 
self.dropout = torch.nn.Dropout(p_dropout) - def forward(self, x): + def forward(self, x, x_mask): """ x (B, T, C) + x_mask (B, T) """ - x = self.non_linearity(self.proj(x.transpose(1, 2))) - x = self.dropout(self.o_net(x).transpose(1, 2)) + x = self.non_linearity(self.proj(x.transpose(1, 2), x_mask)) + x = self.dropout(self.o_net(x, x_mask).transpose(1, 2)) return x @@ -142,6 +147,7 @@ def __init__( d_model: int, p_dropout: float, is_causal: bool = True, + d_head: Optional[int] = None, ): """ Base Attention parent class. Users should not be instantiating this class, but rather use SelfAttention or @@ -154,10 +160,11 @@ def __init__( d_model (int): Dimension of the model. p_dropout (float): Dropout probability. is_causal (bool): Whether to use causal attention. Only supported when used in SelfAttention. + d_head (int): Head dimension. Defaults to d_model // n_heads. """ super().__init__() assert d_model % n_heads == 0, "d_model % n_head != 0" - self.d_head = d_model // n_heads + self.d_head = d_head if d_head is not None else d_model // n_heads self.n_heads = n_heads self.d_model = d_model self.scale = self.d_head**-0.5 @@ -227,12 +234,27 @@ def attn_naive( # attn_prior or square mask or vanilla attention if attn_prior is not None: - eps = 1e-8 + eps = torch.finfo(attn_prior.dtype).tiny attn_prior = attn_prior[:, :T] # trim for inference - attn_prior = torch.log(attn_prior + eps) - attn_prior = attn_prior[:, None].repeat(1, self.n_heads, 1, 1) - attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior - attn_prob = F.softmax(attn_score_log, dim=-1) + attn_prior = attn_prior[:, None] + eps + # Use PyTorch's built-in training flag to branch behavior + if self.training: + attn_prior_log = torch.log(attn_prior) + attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior_log + if self.make_prior_window_strict: + # Make sure attention scores are lowest (eps) where prior is zero. + min_score = torch.log(torch.tensor(eps)).to(attn_score_log.device) + attn_score_log = attn_score_log.masked_fill( + attn_prior == 0, min_score + ) # Wherever prior is zero, set scores to eps. + attn_score_log = torch.clamp( + attn_score_log, min=min_score + ) # Make sure scores are not less than eps. + attn_prob = F.softmax(attn_score_log, dim=-1) + else: + attn_prob = F.softmax(attn_score, dim=-1) + attn_prob = attn_prob * attn_prior + attn_prob = attn_prob / (attn_prob.sum(dim=-1, keepdim=True)) # normalize else: attn_prob = F.softmax(attn_score, dim=-1) @@ -333,7 +355,14 @@ def compute_qkv_and_mask( v = torch.cat([self.cache['self_v'], v], dim=1) self.cache['self_k'] = k self.cache['self_v'] = v - mask = query_mask[:, None, :, None] if query_mask is not None else None + + mask = None + if query_mask is not None: + # query_mask is a boolean mask of shape (B, T) + # mask should be of shape (B, 1, T, T) where mask[:,0,i,:] == mask[:,0,:,i] == query_mask + mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2) + mask = mask.unsqueeze(1) + return q, k, v, mask @@ -344,6 +373,8 @@ def __init__( d_model: int, d_memory: int, p_dropout: float, + make_prior_window_strict: bool = False, + d_head: Optional[int] = None, ): """ Implements CrossAttention. See parent class for forward implementation. Must be non-causal. @@ -353,17 +384,19 @@ def __init__( d_model (int): Dimension of the model. d_memory (int): Dimension of the conditioning / cross-attention input. p_dropout (float): Dropout probability. + make_prior_window_strict (bool): Make attention scores lowest where prior is zero. + d_head (int): Head dimension. 
if None, defaults to d_model // n_heads in parent class. """ super().__init__( n_heads=n_heads, d_model=d_model, p_dropout=p_dropout, is_causal=False, + d_head=d_head, ) - if d_memory is None: - raise ValueError("d_memory must be provided for cross-attention") self.q_net = torch.nn.Linear(d_model, n_heads * self.d_head, bias=False) self.kv_net = torch.nn.Linear(d_memory, 2 * n_heads * self.d_head, bias=False) + self.make_prior_window_strict = make_prior_window_strict def compute_qkv_and_mask( self, @@ -406,10 +439,12 @@ def __init__( has_xattn: bool, xa_d_memory: Optional[int] = None, xa_n_heads: Optional[int] = None, + xa_d_head: Optional[int] = None, is_causal: bool = True, apply_norm_to_cond: bool = True, max_length_causal_mask: int = 4096, conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), + make_prior_window_strict: bool = False, ): """ One layer of the Transformer. @@ -422,10 +457,12 @@ def __init__( has_xattn : Whether to use cross attention xa_d_memory : Hidden dimension for cross attention xa_n_heads : Number of attention heads used in cross attention + xa_d_head : Head dimension for cross attention. if None, defaults to d_model // xa_n_heads in Attention class. is_causal : Whether to use causal attention apply_norm_to_cond : Whether to apply normalization to conditioning tensor max_length_causal_mask : Maximum length of causal mask conv_non_linearity : Convolution non-linearity + make_prior_window_strict : Make attention scores lowest where prior is zero. """ super().__init__() self.has_xattn = has_xattn @@ -440,16 +477,18 @@ def __init__( ) if self.has_xattn: - self.apply_norm_to_cond = apply_norm_to_cond self.norm_xattn_query = torch.nn.LayerNorm(d_model, bias=False) self.cross_attention = CrossAttention( n_heads=xa_n_heads, d_model=d_model, d_memory=xa_d_memory, p_dropout=p_dropout, + make_prior_window_strict=make_prior_window_strict, + d_head=xa_d_head, ) - if self.apply_norm_to_cond: + self.norm_xattn_memory = torch.nn.Identity() + if apply_norm_to_cond: self.norm_xattn_memory = torch.nn.LayerNorm(xa_d_memory, bias=False) self.norm_pos_ff = torch.nn.LayerNorm(d_model, bias=False) @@ -510,7 +549,7 @@ def forward( if self.use_cache and self.cache['memory'] is not None: memory = self.cache['memory'] else: - memory = self.norm_xattn_memory(cond) if self.apply_norm_to_cond else cond + memory = self.norm_xattn_memory(cond) if self.use_cache: self.cache['memory'] = memory @@ -524,7 +563,7 @@ def forward( x = x + x_res # mlp final projection - x = x + self.pos_ff(self.norm_pos_ff(x)) + x = x + self.pos_ff(self.norm_pos_ff(x), x_mask) x = x * x_mask.unsqueeze(-1) return { @@ -546,12 +585,14 @@ def __init__( has_xattn: bool = False, xa_d_memory: Optional[int] = None, xa_n_heads: Optional[int] = None, + xa_d_head: Optional[int] = None, is_causal: bool = True, apply_norm_to_cond: bool = True, apply_norm_out: bool = False, max_length_causal_mask: int = 4096, use_learnable_pos_emb: bool = False, conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), + make_prior_window_strict: bool = False, ): """ Initializes a stack of transformer layers. Can be used for both encoder and decoder. @@ -567,6 +608,7 @@ def __init__( has_xattn : Whether to use cross attention xa_d_memory : Hidden dimension for cross attention; required if has_xattn is True xa_n_heads : Number of attention heads used in cross attention; required if has_xattn is True + xa_d_head : Head dimension for cross attention. if None, defaults to d_model // xa_n_heads in Attention class. 
is_causal : Whether to make attention and the convolution feedforward networks causal. apply_norm_to_cond : Whether to apply normalization to conditioning tensor; conditioning tensor being the input to the memory part of cross-attention. @@ -574,27 +616,26 @@ def __init__( max_length_causal_mask : Maximum length of causal mask use_learnable_pos_emb : Whether to add a learnable positionable embedding inside the class conv_non_linearity : Convolution non-linearity + make_prior_window_strict : Make attention scores lowest where prior is zero """ if has_xattn and (xa_d_memory is None or xa_n_heads is None): raise ValueError("It requires that `xa_d_memory` and `xa_n_heads` are specified when `has_xattn` is True!") super().__init__() + self.n_layers = n_layers self.dropout = torch.nn.Dropout(p_dropout) self.p_dropout_out = p_dropout_out + self.dropout_out = torch.nn.Identity() if self.p_dropout_out > 0.0: self.dropout_out = torch.nn.Dropout(self.p_dropout_out) - else: - self.dropout_out = None - self.apply_norm_out = apply_norm_out - if self.apply_norm_out: + self.norm_out = torch.nn.Identity() + if apply_norm_out: self.norm_out = torch.nn.LayerNorm(d_model, bias=False) - else: - self.norm_out = None self.layers = torch.nn.ModuleList() - for _ in range(n_layers): + for _ in range(self.n_layers): self.layers.append( TransformerLayer( d_model=d_model, @@ -605,10 +646,12 @@ def __init__( has_xattn=has_xattn, xa_d_memory=xa_d_memory, xa_n_heads=xa_n_heads, + xa_d_head=xa_d_head, is_causal=is_causal, apply_norm_to_cond=apply_norm_to_cond, max_length_causal_mask=max_length_causal_mask, conv_non_linearity=conv_non_linearity, + make_prior_window_strict=make_prior_window_strict, ) ) @@ -622,7 +665,7 @@ def __init__( self.apply(self._init_weights_gpt2) for name, param in self.named_parameters(): if 'o_net' in name and name.endswith('weight'): - torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * n_layers)) + torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers)) def reset_cache(self, use_cache=False): for layer in self.layers: @@ -648,13 +691,22 @@ def _get_layer_inputs( if multi_encoder_mapping[idx] is None: return None, None, None else: + _attn_prior = attn_prior[multi_encoder_mapping[idx]] if attn_prior is not None else None + if isinstance(_attn_prior, list): + # @pneekhara: This means, we are passing layerwise attn_prior + _attn_prior = _attn_prior[idx] return ( cond[multi_encoder_mapping[idx]], cond_mask[multi_encoder_mapping[idx]] if cond_mask is not None else None, - attn_prior[multi_encoder_mapping[idx]] if attn_prior is not None else None, + _attn_prior, ) else: - return cond, cond_mask, attn_prior + if isinstance(attn_prior, list): + # @pneekhara: This means, we are passing layerwise attn_prior + _attn_prior = attn_prior[idx] + else: + _attn_prior = attn_prior + return cond, cond_mask, _attn_prior def forward( self, @@ -664,6 +716,7 @@ def forward( cond_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, attn_prior: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, multi_encoder_mapping: Optional[List[Optional[int]]] = None, + max_layer_idx: Optional[int] = None, ) -> Dict[str, Union[torch.Tensor, List]]: """ Args: @@ -701,10 +754,9 @@ def forward( x = out_dict['output'] attn_probabilities.append(out_dict['attn_probabilities']) - if self.norm_out is not None: - x = self.norm_out(x) - - if self.dropout_out is not None: - x = self.dropout_out(x) + if max_layer_idx is not None and idx == max_layer_idx: + break + x = self.norm_out(x) + x = 
self.dropout_out(x) return {'output': x, 'attn_probabilities': attn_probabilities} diff --git a/nemo/collections/tts/modules/utmosv2.py b/nemo/collections/tts/modules/utmosv2.py new file mode 100644 index 000000000000..1da212a75da0 --- /dev/null +++ b/nemo/collections/tts/modules/utmosv2.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import utmosv2 +except ImportError: + raise ImportError( + "UTMOSv2 is not installed. Please install it using `pip install git+https://github.com/sarulab-speech/UTMOSv2.git@v1.2.1`." + ) +from typing import Optional +import torch +from threadpoolctl import threadpool_limits + +""" +Uses the UTMOSv2 model to estimate the MOS of a speech audio file. +""" + + +class UTMOSv2Calculator: + """ + Wrapper around UTMOSv2 MOS estimator to make it easy to use. + Args: + device: The device to place the model on. If None, the best available device will be used. + Default is None. + """ + + def __init__(self, device: Optional[str] = None): + if device is None: + device = get_available_device() + self.model = utmosv2.create_model() + self.model.eval() + self.model.to(torch.device(device)) + + def __call__(self, file_path): + """ + Estimate the MOS of the given speech audio file using UTMOSv2. + """ + with torch.inference_mode(): + # UTMOSv2 tends to launch many OpenMP threads which can overload the machine's CPUs + # without actually speeding up prediction. Limit to 4 threads. + with threadpool_limits(limits=4): + mos_score = self.model.predict(input_path=file_path, num_repetitions=1, num_workers=0) + return mos_score + + def process_directory(self, input_dir: str, batch_size: int = 16) -> list[dict[str, str | float]]: + """ + Computes UTMOSv2 scores for all `*.wav` files in the given directory. + Args: + input_dir: The directory containing the audio files. + batch_size: The number of audio files to process in parallel. + Returns: + A list of dictionaries, each containing the file path and the UTMOSv2 score. + """ + with torch.inference_mode(): + # UTMOSV2 tends to launch many of OpenMP threads which overloads the machine's CPUs + # while actually slowing down the prediction. Limit the number of threads here. + with threadpool_limits(limits=1): + results = self.model.predict( + input_dir=input_dir, num_repetitions=1, num_workers=batch_size, batch_size=batch_size + ) + return results + + +def get_available_device(): + """ + Get the best available device (prefer GPU, fallback to CPU). 
+ """ + if torch.cuda.is_available(): + return "cuda:0" # Use first GPU + else: + return "cpu" diff --git a/nemo/collections/tts/parts/preprocessing/feature_processors.py b/nemo/collections/tts/parts/preprocessing/feature_processors.py index 19ed8139ae65..ccbc2057101f 100644 --- a/nemo/collections/tts/parts/preprocessing/feature_processors.py +++ b/nemo/collections/tts/parts/preprocessing/feature_processors.py @@ -19,10 +19,7 @@ import torch -from nemo.utils.decorators import experimental - -@experimental class FeatureProcessor(ABC): @abstractmethod def process(self, training_example: dict) -> None: diff --git a/nemo/collections/tts/parts/preprocessing/features.py b/nemo/collections/tts/parts/preprocessing/features.py index 5067e89f52f8..802f3d2f079c 100644 --- a/nemo/collections/tts/parts/preprocessing/features.py +++ b/nemo/collections/tts/parts/preprocessing/features.py @@ -25,10 +25,8 @@ from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor from nemo.collections.tts.parts.utils.tts_dataset_utils import get_audio_filepaths, normalize_volume, stack_tensors -from nemo.utils.decorators import experimental -@experimental class Featurizer(ABC): @abstractmethod def save(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path, overwrite: bool = True) -> None: @@ -78,7 +76,10 @@ def _get_feature_filepath( def _features_exists( - feature_names: List[Optional[str]], manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path, + feature_names: List[Optional[str]], + manifest_entry: Dict[str, Any], + audio_dir: Path, + feature_dir: Path, ) -> bool: for feature_name in feature_names: if feature_name is None: diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 1856dee0ce0f..f0d2dc363237 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -31,7 +31,6 @@ from nemo.collections.tts.parts.utils.helpers import create_plot from nemo.utils import logging -from nemo.utils.decorators import experimental HAVE_WANDB = True try: @@ -127,7 +126,6 @@ def generate_artifacts( """ -@experimental class LoggingCallback(Callback): """ Callback which can log artifacts (eg. model predictions, graphs) to local disk, Tensorboard, and/or WandB. 
diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index 2dc555872d10..98ddb0b1ef18 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -134,14 +134,14 @@ def binarize_attention_parallel(attn, in_lens, out_lens): def get_mask_from_lengths( - lengths: Optional[torch.Tensor] = None, - x: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, x: Optional[torch.Tensor] = None, pad_to_factor: Optional[int] = None ) -> torch.Tensor: """Constructs binary mask from a 1D torch tensor of input lengths Args: lengths: Optional[torch.tensor] (torch.tensor): 1D tensor with lengths x: Optional[torch.tensor] = tensor to be used on, last dimension is for mask + pad_to_factor: Optional[int] = pad the mask to an integer multiple of this factor Returns: mask (torch.tensor): num_sequences x max_length binary tensor """ @@ -153,6 +153,8 @@ def get_mask_from_lengths( max_len = torch.max(lengths) else: max_len = x.shape[-1] + if pad_to_factor is not None: + max_len = torch.ceil(max_len / pad_to_factor) * pad_to_factor ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) mask = ids < lengths.unsqueeze(1) return mask @@ -440,12 +442,15 @@ def tacotron2_log_to_wandb_func( swriter.log({"audios": audios}) -def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None): +def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None, attended=None): if phoneme_seq: fig, ax = plt.subplots(figsize=(15, 10)) else: fig, ax = plt.subplots(figsize=(6, 4)) im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax) + if attended is not None: + for step in range(len(attended) - 1): + plt.plot([step, step + 1], [attended[step], attended[step + 1]], color='red', linewidth=1, linestyle='--') ax.set_title(title) fig.colorbar(im, ax=ax) xlabel = 'Decoder timestep' diff --git a/nemo/utils/nemo_logging.py b/nemo/utils/nemo_logging.py index bcc7ad199603..d19d5b37c30b 100644 --- a/nemo/utils/nemo_logging.py +++ b/nemo/utils/nemo_logging.py @@ -31,11 +31,14 @@ class LogMode(enum.IntEnum): + """Enum to control how many times to log messages in NeMo logging""" + EACH = 0 # Log the message each time ONCE = 1 # Log the message only once. The same message will not be logged again. class Logger(metaclass=Singleton): + """NeMo's logging class. 
Makes some changes on top of python's logging module to aid model devs.""" # Level 0 NOTSET = _logging.NOTSET @@ -378,7 +381,7 @@ def debug(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.debug("Houston, we have a %s", "thorny problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.DEBUG) and not self._logged_once(msg, mode): - self._logger._log(Logger.DEBUG, msg, args, **kwargs) + self._logger._log(Logger.DEBUG, msg, args, **kwargs, stacklevel=2) def info(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -390,7 +393,7 @@ def info(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.info("Houston, we have a %s", "interesting problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.INFO) and not self._logged_once(msg, mode): - self._logger._log(Logger.INFO, msg, args, **kwargs) + self._logger._log(Logger.INFO, msg, args, **kwargs, stacklevel=2) def warning(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -402,7 +405,7 @@ def warning(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.WARNING) and not self._logged_once(msg, mode): - self._logger._log(Logger.WARNING, msg, args, **kwargs) + self._logger._log(Logger.WARNING, msg, args, **kwargs, stacklevel=2) def error(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -414,7 +417,7 @@ def error(self, msg, *args, mode=LogMode.EACH, **kwargs): logger.error("Houston, we have a %s", "major problem", exc_info=1) """ if self._logger is not None and self._logger.isEnabledFor(Logger.ERROR) and not self._logged_once(msg, mode): - self._logger._log(Logger.ERROR, msg, args, **kwargs) + self._logger._log(Logger.ERROR, msg, args, **kwargs, stacklevel=2) def critical(self, msg, *args, mode=LogMode.EACH, **kwargs): """ @@ -430,4 +433,4 @@ def critical(self, msg, *args, mode=LogMode.EACH, **kwargs): and self._logger.isEnabledFor(Logger.CRITICAL) and not self._logged_once(msg, mode) ): - self._logger._log(Logger.CRITICAL, msg, args, **kwargs) + self._logger._log(Logger.CRITICAL, msg, args, **kwargs, stacklevel=2) diff --git a/requirements/requirements_tts.txt b/requirements/requirements_tts.txt index af2a76932779..927f493e652e 100644 --- a/requirements/requirements_tts.txt +++ b/requirements/requirements_tts.txt @@ -13,4 +13,5 @@ pandas pypinyin pypinyin-dict seaborn +utmosv2 @ git+https://github.com/sarulab-speech/UTMOSv2.git@v1.2.1 diff --git a/scripts/magpietts/README_magpie_po.md b/scripts/magpietts/README_magpie_po.md new file mode 100644 index 000000000000..897287aaf1d5 --- /dev/null +++ b/scripts/magpietts/README_magpie_po.md @@ -0,0 +1,248 @@ +### Offline Preference Alignment (DPO/RPO) + +Code: `nemo/collections/tts/models/magpietts_preference_optimization.py` + +Preference Alignment (DPO/RPO) involves the following steps +1) Create a list of text-context pairs for which we will generate preference data. +2) For each text-context pair generate multiple audios from a base TTS checkpoint and calculate metrics (CER/SSIM) for each generation. +3) Create chosen-rejected pairs from the generated audio. +4) Finetune the base TTS checkpoint on the chosen-rejected pairs. + +#### 1. Create text-context pairs +We pair a list of challenging texts with context audios from our speech datasets. We add a similar number of regular transcripts our datasets such as LibriTTS paired with random context audios. 
We also include examples that use text contexts instead of audio contexts. Other options for generating text-context pairs can be explored depending on the task. + +``` +python scripts/magpietts/dpo/create_text_contextpairs.py \ + --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ + --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ + --nsamples_perpair 6 ; +``` +Each pair is repeated `nsamples_perpair` times, which specifies how many samples we want to generate for each pair. The output manifest serves as the input for the next step. + +#### 2. Generate audios for each text-context pair + +Next, we generate audios from a base TTS checkpoint using the following command. We pass `audio_dir` as "/" since our text-context pairs contain absolute paths. Model config arguments should be modified to match the base checkpoint architecture. The command below can be run on a cluster to generate audios across multiple nodes. It saves the generated audios, along with the metrics for each generation, in the `exp_dir`. Each generated audio file is accompanied by a `.json` file containing its CER/SSIM metrics. + + +``` +python examples/tts/magpietts.py \ +--config-name=magpietts_po_inference \ +mode=test \ +batch_size=64 \ ++init_from_ptl_ckpt= \ +exp_manager.exp_dir= \ ++test_ds_meta.textcontextpairs.manifest_path= \ ++test_ds_meta.textcontextpairs.audio_dir="/" \ ++test_ds_meta.textcontextpairs.feature_dir="/" \ +model.model_type="decoder_context_tts" # Change this as needed \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path= \ +model.alignment_loss_scale=0.002 \ +model.prior_scaling_factor=null \ +model.load_cached_codes_if_available=false +``` +#### 3. Create chosen-rejected pairs from the generations + +Next, we go through the generated audio directory and create chosen-rejected pairs. + +``` +python scripts/magpietts/dpo/create_preference_pairs.py \ +--input_manifest \ +--generated_audio_dir /MagpieTTS-PO-Infer/version_0/audio \ +--group_size 6 \ +--cer_threshold 0.01 \ +--val_size 256 ; +``` + +`cer_threshold=0.01` means we filter out pairs whose chosen CER is greater than 0.01. + +This command saves train and val manifests for DPO finetuning in the base directory of `generated_audio_dir`, that is, `/MagpieTTS-PO-Infer/version_0/manifests/`.
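As a simplified illustration of this step, the sketch below shows the chosen/rejected selection idea: within each group of generations for the same text-context pair, a low-CER/high-SSIM sample is preferred over a high-CER/low-SSIM one. The record keys `cer` and `ssim` are placeholders for this example; the actual logic (Pareto ranking, duration filters, manifest writing) lives in `create_preference_pairs.py`.

```
# Illustrative only: pick a chosen/rejected pair from one group of generations.
# 'cer' (lower is better) and 'ssim' (higher is better) are placeholder keys;
# the real script uses Pareto ranking over its own record fields.
def pick_pair(group, cer_threshold=0.01):
    ranked = sorted(group, key=lambda r: (r["cer"], -r["ssim"]))
    chosen, rejected = ranked[0], ranked[-1]
    if chosen["cer"] > cer_threshold:
        return None  # mirrors --cer_threshold: drop groups whose best sample is still too erroneous
    return chosen, rejected

# Example: three generations for one text-context pair.
group = [{"cer": 0.00, "ssim": 0.80}, {"cer": 0.12, "ssim": 0.60}, {"cer": 0.05, "ssim": 0.70}]
pair = pick_pair(group)  # -> ({'cer': 0.0, 'ssim': 0.8}, {'cer': 0.12, 'ssim': 0.6})
```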
+#### 4. DPO Finetuning Command + +Finally, we perform DPO finetuning using the following command: + +``` +python examples/tts/magpietts.py \ +batch_size=4 \ ++init_from_ptl_ckpt= \ ++mode="dpo_train" \ +max_epochs=10 \ +exp_manager.exp_dir= \ +exp_manager.checkpoint_callback_params.always_save_nemo=false \ +model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ ++train_ds_meta.dpopreftrain.manifest_path="/MagpieTTS-PO-Infer/version_0/manifests/" \ ++train_ds_meta.dpopreftrain.audio_dir="/" \ ++train_ds_meta.dpopreftrain.feature_dir="/" \ ++val_ds_meta.dpoprefval.manifest_path="/MagpieTTS-PO-Infer/version_0/manifests/dpo_val_manifest.json" \ ++val_ds_meta.dpoprefval.audio_dir="/" \ ++val_ds_meta.dpoprefval.feature_dir="/" \ ++model.dpo_beta=0.01 \ ++model.dpo_sft_loss_weight=0.0 \ +model.model_type="decoder_context_tts" # Change this as needed \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.use_text_conditioning_encoder=true \ +model.codecmodel_path= \ +model.alignment_loss_scale=0.001 \ +model.prior_scaling_factor=null \ +trainer.val_check_interval=200 \ +trainer.log_every_n_steps=10 \ +model.optim.lr=2e-7 \ +~model.optim.sched +``` + +Note the following overrides in the above command: + +``` ++mode="dpo_train" \ +model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +``` + +Again, our manifests contain absolute paths, so we specify `audio_dir="/"`. + +### Online Preference Optimization (GRPO) + +For online preference optimization, the process is much simpler. + +1) Create a list of text-context pairs for which we will generate preference data (just one entry per text-context pair, not repeated). +We use the same process as above; just set `--nsamples_perpair 1` in the command. +``` +python scripts/magpietts/dpo/create_text_contextpairs.py \ + --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ + --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \ + --nsamples_perpair 1 ; +``` + +2) Train using GRPO + +To train with GRPO, we use a training command similar to the base model training, with a few modifications: + +1. We start from a pretrained checkpoint supplied via `+init_from_ptl_ckpt`. +2. We add `+mode="onlinepo_train"` to specify preference-optimization-based training. +3. Use a small batch size (bs=2) since we generate `num_generations_per_item` samples per item in the batch and the effective batch size becomes `bs*num_generations_per_item`. +4. The manifest should contain absolute audio paths, and `audio_dir` is specified as "/" in the `train_ds_meta` overrides. +5. Use the same model-specific overrides as the base model (e.g. x-attn heads, is_causal, num_layers, local transformer, etc.). +6. Set dropout probabilities to 0 for all modules, e.g. `model.decoder.p_dropout=0.0`. This is especially important if we are not using reference-free mode, since the KL divergence loss otherwise becomes very spiky and unstable.
+7. Don't use the attention prior or CTC loss during GRPO. +8. Add the following GRPO-specific arguments to the training command. + +``` ++model.grpo_beta=0.0 \ # Coefficient for the KL loss (if not using reference-free mode) ++model.num_generations_per_item=12 \ # 12 samples are generated for each item and a reward is computed for each ++model.reference_free=true \ # Reference-free means we don't use the KL loss term and only optimize for rewards ++model.inference_cfg_prob=0.0 \ # Fraction of generations produced with CFG. Can be set > 0.0 to optimize for both CFG and non-CFG modes of generation ++model.use_local_transformer_prob=0.5 \ # Fraction of generations produced with the Local Transformer. Set it between 0.0 and 1.0 to improve both LT and non-LT outputs for models with an LT ++model.inference_cfg_scale=2.5 \ # CFG scale for samples generated using CFG ++model.cer_reward_weight=0.33 \ # Weight of the CER reward in the overall reward ++model.ssim_reward_weight=0.33 \ # Weight of the SSIM reward in the overall reward ++model.pesq_reward_weight=0.33 \ # Weight of the PESQ reward in the overall reward ++model.use_pesq=true \ # Set this to true if using the PESQ reward ++model.reward_asr_model="whisper" \ # Use whisper only for multilingual settings; don't specify for English +model.cfg_unconditional_prob=0.0 \ # Set this to 0; we don't want to drop out the unconditional input ++model.inference_topk=2016 \ # Top-K - not yet sure whether topk=80 helps; top_k=2016 effectively disables top-k ++model.inference_temperature=0.8 \ # Slightly higher temperature for more variety of generations in preference optimization ++model.use_kv_cache_during_online_po=true \ # Use KV caching while generating samples for GRPO ++model.loss_type="grpo" \ # Can be grpo or dr_grpo; grpo works better in my experiments ++model.scale_rewards=true \ # Whether to divide advantages by the std deviation (set true for GRPO and false for DR_GRPO) ++model.max_decoder_steps=430 \ # Max steps for generation +``` + +9. We also want to validate more frequently during GRPO since each step takes longer, so we add the following args: +``` +~trainer.check_val_every_n_epoch \ ++trainer.val_check_interval=50 \ +``` + +10. We use a lower learning rate and save the best checkpoints based on the lowest CER on our validation set using: +``` +model.optim.lr=1e-7 \ +~model.optim.sched \ +exp_manager.checkpoint_callback_params.monitor="val_cer_gt" \ +exp_manager.checkpoint_callback_params.mode="min" \ +``` + +11.
Specify precision and gradient clipping as necessary +``` +trainer.precision=32 \ ++trainer.gradient_clip_val=2.5 \ +``` + + +Below is a sample training command for multilingual GRPO: + +``` +python examples/tts/magpietts.py \ +--config-name=magpietts_multilingual_v1 #TODO(blisc) after updating yamls\ +batch_size=2 \ ++init_from_ptl_ckpt= \ ++mode="onlinepo_train" \ ++model.text_tokenizers.chartokenizer._target_=AutoTokenizer # Change this as needed \ ++model.text_tokenizers.chartokenizer.pretrained_model="google/byt5-small" # Change this as needed \ +max_epochs=20 \ +exp_manager.exp_dir= \ ++exp_manager.version=0 \ +exp_manager.checkpoint_callback_params.always_save_nemo=false \ ++train_ds_meta.dpopreftrain.manifest_path= \ ++train_ds_meta.dpopreftrain.audio_dir="/" \ ++train_ds_meta.dpopreftrain.feature_dir="/" \ ++train_ds_meta.dpopreftrain.tokenizer_names="[chartokenizer]" #Change this as needed \ ++val_ds_meta.dpoprefval.manifest_path= \ ++val_ds_meta.dpoprefval.audio_dir="/" \ ++val_ds_meta.dpoprefval.feature_dir="/" \ ++val_ds_meta.dpoprefval.tokenizer_names="[chartokenizer]" #Change this as needed \ ++model.grpo_beta=0.0 \ ++model.num_generations_per_item=12 \ ++model.reference_free=true \ ++model.inference_cfg_prob=0.0 \ ++model.inference_cfg_scale=2.5 \ ++model.cer_reward_weight=0.33 \ ++model.ssim_reward_weight=0.33 \ ++model.pesq_reward_weight=0.33 \ ++model.use_pesq=true \ ++model.reward_asr_model="whisper" \ +model.cfg_unconditional_prob=0.0 \ ++model.inference_topk=2016 \ ++model.inference_temperature=0.8 \ ++model.use_kv_cache_during_online_po=true \ ++model.loss_type="grpo" \ ++model.scale_rewards=true \ ++model.max_decoder_steps=430 \ +model.model_type="decoder_context_tts" #Change this as needed \ +model.context_duration_min=5.0 \ +model.context_duration_max=5.0 \ +model.decoder.p_dropout=0.0 \ +model.encoder.p_dropout=0.0 \ +model.local_transformer_type="autoregressive" #Change this as needed \ +model.local_transformer_n_layers=1 #Change this as needed \ +model.local_transformer_n_heads=1 #Change this as needed \ +model.local_transformer_hidden_dim=256 #Change this as needed \ +model.use_text_conditioning_encoder=true #Change this as needed \ +model.codecmodel_path= \ +model.alignment_loss_scale=0.0 \ +model.prior_scaling_factor=null \ +~trainer.check_val_every_n_epoch \ ++trainer.val_check_interval=50 \ +trainer.log_every_n_steps=10 \ +model.optim.lr=1e-7 \ +~model.optim.sched \ +exp_manager.checkpoint_callback_params.monitor="val_cer_gt" \ +exp_manager.checkpoint_callback_params.mode="min" \ +trainer.precision=32 \ ++trainer.gradient_clip_val=2.5 \ +``` + diff --git a/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py b/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py new file mode 100644 index 000000000000..f3acb91848f8 --- /dev/null +++ b/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py @@ -0,0 +1,561 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script requires the following updates to lhotse: add `shard_offset` in lhotse's writers. +$ pip install git+https://github.com/lhotse-speech/lhotse.git@883c24b5f6cdc4bbc73e89186e99f7907262b59c + +Example of manifest: + { + "audio_filepath": "train-clean-360/4098/11547/4098_11547_000032_000000.wav", + "text": "\"Isn't it?\" queried Theo.", + "speaker": "| Language:en Dataset:LibriTTS Speaker:4098 |", + "chapter_id": "11547", + "utter_id": "000032_000000", + "duration": 1.9700416666666667, + "normalized_text": "\"Isn't it?\" queried Theo.", + "context_speaker_similarity": 0.7800518870353699, + "context_audio_filepath": "train-clean-360/4098/11547/4098_11547_000031_000001.wav", + "context_audio_duration": 9.45 + } + +Example usage: + python scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py \ + --manifest-path ${MANIFEST} \ + --audio-base-dir ${AUDIO_BASE_DIR} \ + --output-dir ${OUTPUT_DIR} \ + --num-jobs ${NUM_JOBS} \ + --processing-chunk-size ${CHUNK_SIZE} \ + --audio-format ${AUDIO_FORMAT} \ + --log-level ${LOG_LEVEL} \ + --shuffle \ + --shuffle-seed 42 \ + 2>&1 | tee ./log/create_lhotse_shar_from_nemo_manifest.stdout + +Expected output: + $ tree ${OUTPUT_DIR} + ${OUTPUT_DIR}/ + cuts/ + cuts.000000.jsonl.gz + cuts.000001.jsonl.gz + ... + target_audio/ + recording.000000.tar + recording.000001.tar + ... + context_audio/ + recording.000000.tar + recording.000001.tar + ... +""" + +import argparse +import itertools +import logging +import math +import os +import random +import re +from concurrent.futures import ProcessPoolExecutor, as_completed +from functools import partial +from pathlib import Path +from typing import Any, Dict, Tuple + +from lhotse import AudioSource, MonoCut, Recording, SupervisionSegment, compute_num_samples, fastcopy +from lhotse.serialization import load_jsonl +from lhotse.shar.writers import AudioTarWriter, JsonlShardWriter +from tqdm import tqdm + +NEMO_KEYS_NO_NEED_TO_LOG_IN_CUSTOM_FIELDS_FOR_SUPERVISION = [ + "audio_filepath", + "context_audio_filepath", + "text", + "offset", + "duration", + "speaker", +] + + +def to_shar_placeholder(recording: Recording, cut: MonoCut) -> Recording: + """this function is borrowed from lhotse.shar.writers.to_shar_placeholder. The only change for Recording instance is to update the id with cut.id.""" + return fastcopy( + recording, + id=cut.id, + # Creates a single AudioSource out of multiple ones. + sources=[AudioSource(type="shar", channels=recording.channel_ids, source="")], + # Removes the transform metadata because they were already executed. + transforms=None, + duration=cut.duration, + num_samples=compute_num_samples(cut.duration, recording.sampling_rate), + ) + + +def check_speaker_format(item: str): + """Enforce speaker format like '| Language:en Dataset:HiFiTTS Speaker:9136_other |'""" + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + if not isinstance(item, str): + return False + return bool(re.match(pattern, item)) + + +def get_recording_id(relative_path: str) -> str: + """Generate a recording ID from the relative audio path.""" + return "rec-" + relative_path.rsplit(".", 1)[0].replace("/", "-") + + +def process_manifest_entry(entry: Dict[str, Any], audio_base_dir: Path) -> Tuple[MonoCut, MonoCut] | None: + """ + Processes a single entry from the NeMo manifest to create Lhotse objects. + + Returns: + tuple: (target_cut, context_cut) or None if an error occurs. 
+ """ + try: + # Required fields + target_audio_path_relative = entry.get("audio_filepath") + context_audio_path_relative = entry.get("context_audio_filepath") + target_audio_duration = entry.get("duration") + context_audio_duration = entry.get("context_audio_duration") + text = entry.get("text") + # observed cases when text is empty while normalized_text is not. + if not text or not text.strip(): + text = entry.get("normalized_text") + speaker = entry.get("speaker") + + # Check required fields + if not all( + [ + target_audio_path_relative, + context_audio_path_relative, + target_audio_duration, + context_audio_duration, + text, + speaker, + ] + ): + logging.warning(f"Skipping entry due to missing fields: {entry}") + return None + + # Check speaker format + if not check_speaker_format(speaker): + logging.warning(f"Skipping entry due to incorrect speaker format: {entry}") + return None + + target_audio_filepath = audio_base_dir / target_audio_path_relative + context_audio_filepath = audio_base_dir / context_audio_path_relative + + if not target_audio_filepath.is_file(): + logging.warning( + f"Skipping entry due to missing target audio file: {target_audio_filepath} from entry: {entry}" + ) + return None + if not context_audio_filepath.is_file(): + logging.warning( + f"Skipping entry due to missing context audio file: {context_audio_filepath} from entry: {entry}" + ) + return None + + # Create IDs + target_recording_id = get_recording_id(target_audio_path_relative) + context_recording_id = get_recording_id(context_audio_path_relative) + + # Create Recordings + # TODO: if input is FLAC, then we should set AudioSegment.from_file(int_values=True). Does this applies to lhotse? + target_recording = Recording.from_file(target_audio_filepath, recording_id=target_recording_id) + context_recording = Recording.from_file(context_audio_filepath, recording_id=context_recording_id) + + # Custom fields exist in manifests, so better to keep them for future usage. + custom_fields = { + key: val + for key, val in entry.items() + if key not in NEMO_KEYS_NO_NEED_TO_LOG_IN_CUSTOM_FIELDS_FOR_SUPERVISION + } + custom_fields["context_recording_id"] = context_recording_id + + # Extract language from speaker string + lang_match = re.search(r"Language:(\w+)", speaker) + language = lang_match.group(1) if lang_match else None + + # offset in seconds + target_offset_in_seconds = entry.get("offset", 0.0) + context_offset_in_seconds = entry.get("context_audio_offset", 0.0) + + # Create Supervision for target cut. We constrain one supervision per cut for now. + supervision = SupervisionSegment( + id=f"sup-{target_recording_id}", + recording_id=target_recording_id, + start=target_offset_in_seconds, + duration=target_audio_duration, # duration from manifest + channel=0, # only support mono audio for now + text=text, + language=language, + speaker=speaker, + custom=custom_fields, + ) + + # Create target cut + target_cut_id = f"cut-{target_recording_id}-{target_offset_in_seconds:.2f}-{target_audio_duration:.2f}" + target_cut = MonoCut( + id=target_cut_id, + start=target_offset_in_seconds, + duration=target_audio_duration, + channel=0, # only support mono audio for now + recording=target_recording, + supervisions=[supervision], + ) + if not math.isclose(target_cut.duration, target_audio_duration, abs_tol=0.1): + logging.warning( + f"Manifest duration ({target_audio_duration}) differs significantly from cut duration ({target_cut.duration}) for {target_recording_id}. Using cut duration." 
+ ) + target_cut.supervisions[0].duration = target_cut.duration + + # Create context cut. This cut is only used to load segmented audio and would not be stored in the final manifest. + context_cut_id = ( + f"context_cut-{context_recording_id}-{context_offset_in_seconds:.2f}-{context_audio_duration:.2f}" + ) + if context_cut_id.split("-", 1)[1] == target_cut_id.split("-", 1)[1]: + logging.warning(f"Context cut has the same recording segment as target cut. Skipping entry: {entry}") + return None + + context_cut = MonoCut( + id=context_cut_id, + start=context_offset_in_seconds, + duration=context_audio_duration, + channel=0, # only support mono audio for now + recording=context_recording, + ) + return target_cut, context_cut + + except Exception as e: + logging.error(f"Skipping entry due to error during metadata processing: {entry}: {e}", exc_info=True) + return None + + +def shuffle_jsonl_file(input_path: Path, seed: int = None) -> Path: + """ + Shuffle lines in a JSONL file and write to a shuffled copy. + + Args: + input_path: Path to the original JSONL file + seed: Random seed for reproducible shuffling + + Returns: + Path to the shuffled file + """ + if seed is not None: + random.seed(seed) + + logging.info(f"Reading and shuffling manifest entries from {input_path}") + + # Read all lines into memory + with open(input_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + logging.info(f"Loaded {len(lines)} entries, now shuffling...") + + # Shuffle the lines + random.shuffle(lines) + + # Create output path with "_shuffled" suffix + shuffled_path = input_path.parent / f"{input_path.stem}_shuffled{input_path.suffix}" + + # Write shuffled content + with open(shuffled_path, 'w', encoding='utf-8') as f: + f.writelines(lines) + + logging.info(f"Shuffled manifest written to: {shuffled_path}") + return shuffled_path + + +def chunked_iterator(iterable, chunk_size): + """Yield successive chunks from iterable.""" + _it = iter(iterable) + while _chunk := tuple(itertools.islice(_it, chunk_size)): + yield _chunk + + +def process_and_write_chunk( + manifest_chunk_with_idx: Tuple[int, Tuple[Dict[str, Any], ...]], + audio_base_dir: Path, + output_dir: Path, + audio_format: str, +) -> Dict[str, int]: + """ + Processes a chunk of manifest entries, loads audio, and writes corresponding + single shard files for cuts, target audio, and context audio. + Designed to be run in a parallel worker process. + Loads and writes audio iteratively to save memory. + + Returns a dict containing processing stats like 'processed', 'initial_errors', 'audio_load_errors'. + """ + chunk_idx, manifest_chunk = manifest_chunk_with_idx + worker_pid = os.getpid() + logging.debug(f"[Worker {worker_pid}, Chunk {chunk_idx}] Starting processing {len(manifest_chunk)} entries.") + + # --- 1. Process manifest entries to get Cut objects --- + chunk_metadata = [] + initial_errors = 0 + for entry in manifest_chunk: + result = process_manifest_entry(entry, audio_base_dir=audio_base_dir) + if result is not None: + chunk_metadata.append(result) + else: + initial_errors += 1 + + if not chunk_metadata: + logging.warning( + f"[Worker {worker_pid}, Chunk {chunk_idx}] No valid entries after initial processing. Skipping." + ) + return {"processed": 0, "initial_errors": initial_errors, "audio_load_errors": 0, "write_errors": 0} + + logging.debug( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Collected {len(chunk_metadata)} cut pairs after initial processing." + ) + + # --- 2. 
Initialize writers and perform iterative load-and-write --- + cuts_dir = output_dir / "cuts" + target_recordings_dir = output_dir / "target_audio" + context_recordings_dir = output_dir / "context_audio" + + cuts_pattern = str(cuts_dir / "cuts.%06d.jsonl.gz") + target_rec_pattern = str(target_recordings_dir / "recording.%06d.tar") + context_rec_pattern = str(context_recordings_dir / "recording.%06d.tar") + + chunk_processed_count = 0 + chunk_audio_load_errors = 0 # Errors during audio loading phase for this chunk + chunk_write_errors = 0 # Errors during write phase for this chunk + + logging.debug( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Initializing writers with offset {chunk_idx} and processing {len(chunk_metadata)} pairs iteratively..." + ) + try: + # Specify shard_size with len(chunk_metadata) and shard_offset with chunk_idx, ensuring each chunk is written to a separate shard file. + shard_size_for_worker = len(chunk_metadata) + with ( + JsonlShardWriter( + pattern=cuts_pattern, shard_size=shard_size_for_worker, shard_offset=chunk_idx + ) as cut_writer, + AudioTarWriter( + pattern=target_rec_pattern, + shard_size=shard_size_for_worker, + format=audio_format, + shard_offset=chunk_idx, + ) as target_rec_writer, + AudioTarWriter( + pattern=context_rec_pattern, + shard_size=shard_size_for_worker, + format=audio_format, + shard_offset=chunk_idx, + ) as context_rec_writer, + ): + # Iterate directly over chunk_metadata + for target_cut, context_cut in chunk_metadata: + # 1. load target/context audio given the audio offset + try: + target_audio = target_cut.load_audio() + context_audio = context_cut.load_audio() + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Error loading target/context audio for cut {target_cut}: {e}", + exc_info=True, + ) + chunk_audio_load_errors += 1 + continue + + # 2. Write target audio and context audio + try: + target_rec_writer.write( + key=target_cut.id, + value=target_audio, + sampling_rate=target_cut.sampling_rate, + manifest=to_shar_placeholder( + target_cut.recording, target_cut + ), # update manifest.id with target_cut.id that has the audio offset and duration + ) + context_rec_writer.write( + key=target_cut.id, # use target cut id as key for context audio to ensure reference + value=context_audio, + sampling_rate=context_cut.sampling_rate, + manifest=to_shar_placeholder( + context_cut.recording, context_cut + ), # update manifest.id with context_cut.id that has the audio offset and duration + ) + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Error writing target/context audio for target cut {target_cut}: {e}", + exc_info=True, + ) + chunk_write_errors += 1 + continue + + # 3. write cut metadata + try: + cut_writer.write(target_cut) + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Error writing cut metadata for cut {target_cut}: {e}", + exc_info=True, + ) + chunk_write_errors += 1 + continue + + chunk_processed_count += 1 + + except Exception as e: + logging.error( + f"[Worker {worker_pid}, Chunk {chunk_idx}] CRITICAL error during writer initialization: {e}", exc_info=True + ) + chunk_write_errors = len(chunk_metadata) + chunk_processed_count = 0 + + # This part is only reached if the main try block completes without critical errors + logging.debug( + f"[Worker {worker_pid}, Chunk {chunk_idx}] Finished chunk. 
Processed: {chunk_processed_count}, Audio Load Errors: {chunk_audio_load_errors}, Write Errors: {chunk_write_errors}" + ) + + return { + "processed": chunk_processed_count, + "initial_errors": initial_errors, # Errors from initial metadata processing + "audio_load_errors": chunk_audio_load_errors, + "write_errors": chunk_write_errors, + } + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Convert NeMo manifest to sharded Lhotse JSONL/TARs using parallel workers per chunk.", + ) + parser.add_argument("--manifest-path", required=True, type=Path, help="Path to the input NeMo JSON manifest file.") + parser.add_argument( + "--audio-base-dir", required=True, type=Path, help="Base directory where audio files are located." + ) + parser.add_argument("--output-dir", required=True, type=Path, help="Base directory to save the sharded outputs.") + parser.add_argument( + "--num-jobs", + type=int, + default=max(1, os.cpu_count() // 2), + help="Number of parallel worker processes (each processing a whole chunk/shard).", + ) + parser.add_argument( + "--processing-chunk-size", + type=int, + default=4000, + help="Number of manifest entries per chunk (effectively the items per output shard file).", + ) + parser.add_argument( + "--audio-format", type=str, default="flac", help="Audio format for TAR writers (e.g., flac, wav, opus)." + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level for the main process and workers.", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="Shuffle the manifest entries before processing.", + ) + parser.add_argument( + "--shuffle-seed", + type=int, + default=None, + help="Random seed for reproducible shuffling (only used if --shuffle is enabled).", + ) + + args = parser.parse_args() + + # Configure logging based on argument + log_level = getattr(logging, args.log_level.upper(), logging.INFO) + log_format = '%(asctime)s - PID:%(process)d - %(levelname)s - %(message)s' + logging.basicConfig(level=log_level, format=log_format) + + # Ensure output directories exist + cuts_dir = args.output_dir / "cuts" + target_recordings_dir = args.output_dir / "target_audio" + context_recordings_dir = args.output_dir / "context_audio" + cuts_dir.mkdir(parents=True, exist_ok=True) + target_recordings_dir.mkdir(parents=True, exist_ok=True) + context_recordings_dir.mkdir(parents=True, exist_ok=True) + + # Handle shuffling if requested + if args.shuffle: + logging.info(f"Shuffling manifest entries from: {args.manifest_path}") + shuffled_manifest_path = shuffle_jsonl_file(args.manifest_path, seed=args.shuffle_seed) + manifest_iterable = load_jsonl(shuffled_manifest_path) + logging.info(f"Using shuffled manifest for processing: {shuffled_manifest_path}") + else: + logging.info(f"Reading NeMo manifest lazily from: {args.manifest_path}") + manifest_iterable = load_jsonl(args.manifest_path) + + logging.info( + f"Processing manifest in chunks of {args.processing_chunk_size} using {args.num_jobs} parallel workers..." 
+ ) + + total_processed_count = 0 + total_initial_errors = 0 + total_audio_load_errors = 0 + total_write_errors = 0 + num_chunks = 0 + + worker_func = partial( + process_and_write_chunk, + audio_base_dir=args.audio_base_dir, + output_dir=args.output_dir, + audio_format=args.audio_format, + ) + + with ProcessPoolExecutor(max_workers=args.num_jobs) as executor: + # Enumerate chunks to pass index to worker. Each index is the same as the shard_offset. + chunk_iterator = enumerate(chunked_iterator(manifest_iterable, args.processing_chunk_size)) + futures = { + executor.submit(worker_func, chunk_with_idx): chunk_with_idx[0] for chunk_with_idx in chunk_iterator + } + num_chunks = len(futures) + + logging.info(f"Submitted {num_chunks} chunks to workers.") + + for future in tqdm(as_completed(futures), total=num_chunks, desc="Processing Chunks"): + chunk_idx = futures[future] + try: + result = future.result() + total_processed_count += result["processed"] + total_initial_errors += result["initial_errors"] + total_audio_load_errors += result["audio_load_errors"] + total_write_errors += result["write_errors"] + logging.debug(f"Chunk {chunk_idx} finished with result: {result}") + except Exception as e: + logging.error(f"Chunk {chunk_idx} failed with exception: {e}", exc_info=True) + # Increment error count based on chunk size. Difficult to know precisely. Assume all failed. + total_initial_errors += args.processing_chunk_size + + logging.info("=" * 30 + " Processing Summary " + "=" * 30) + logging.info(f"Total chunks processed: {num_chunks}") + logging.info(f"Successfully processed and wrote data for approximately {total_processed_count} entries.") + total_errors = total_initial_errors + total_audio_load_errors + total_write_errors + if total_errors > 0: + logging.warning(f"Encountered errors/skips in {total_errors} potential entries:") + logging.warning(f" - Initial processing errors/skips: {total_initial_errors}") + logging.warning(f" - Audio loading errors/skips (affecting writes): {total_audio_load_errors}") + logging.warning(f" - Writing errors: {total_write_errors}") + logging.warning("Check logs above (use DEBUG level for more details) for specific entry issues.") + else: + logging.info("No significant errors reported.") + logging.info("Manifest creation finished.") + + +if __name__ == "__main__": + main() diff --git a/scripts/magpietts/dpo/create_preference_pairs.py b/scripts/magpietts/dpo/create_preference_pairs.py index 4d8ed40f3bb0..fc9deed2d69c 100644 --- a/scripts/magpietts/dpo/create_preference_pairs.py +++ b/scripts/magpietts/dpo/create_preference_pairs.py @@ -43,6 +43,12 @@ def main(): ) parser.add_argument("--group_size", type=int, default=4) parser.add_argument("--cer_threshold", type=float, default=0.02) + parser.add_argument( + "--min_length_threshold", + type=float, + default=1.5, + help="Minimum length permitted. 
Set this shorter to allow very short sentences (which can be useful for DPO tuning.", + ) parser.add_argument("--val_size", type=int, default=64) args = parser.parse_args() @@ -83,7 +89,7 @@ def main(): print("Len all_best_records: ", len(all_best_records)) print("Len all_worst_records: ", len(all_worst_records)) best_records, worst_records = filter_best_and_worst_records( - all_best_records, all_worst_records, args.cer_threshold + all_best_records, all_worst_records, args.cer_threshold, args.min_length_threshold ) print("Len filtered best_records: ", len(best_records)) print("Len filtered worst_records: ", len(worst_records)) @@ -167,6 +173,8 @@ def pareto_rank(items): # A helper function to check if item A is dominated by item B # A: (cerA, ssimA), B: (cerB, ssimB) def is_dominated(A, B): + assert len(A) == 2 + assert len(B) == 2 return (B[0] <= A[0]) and (B[1] >= A[1]) and (B != A) # Equivalently, check at least one strict inequality: # (B[0] < A[0]) or (B[1] > A[1]) @@ -186,7 +194,7 @@ def is_dominated(A, B): dominated = False for j in range(len(remaining)): if i != j: - if is_dominated(remaining[i], remaining[j]): + if is_dominated(remaining[i][:2], remaining[j][:2]): dominated = True break if not dominated: @@ -303,7 +311,7 @@ def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_gr return best_records, worst_records -def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.02): +def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.02, min_length_threshold=1.5): ridx = 0 filtered_best_records = [] filtered_worst_records = [] @@ -316,7 +324,7 @@ def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.0 if best_record['cer_gts'] < cer_threshold: worst_record = worst_records[ridx] if (worst_record['duration'] > 19.0 or best_record['duration'] > 19.0) or ( - worst_record['duration'] < 1.5 or best_record['duration'] < 1.5 + worst_record['duration'] < min_length_threshold or best_record['duration'] < min_length_threshold ): skipped_records += 1 ridx += 1 diff --git a/scripts/magpietts/dpo/create_text_contextpairs.py b/scripts/magpietts/dpo/create_text_contextpairs.py index 74ee5ff6b92f..029d44235f38 100644 --- a/scripts/magpietts/dpo/create_text_contextpairs.py +++ b/scripts/magpietts/dpo/create_text_contextpairs.py @@ -36,13 +36,18 @@ def main(): The resulting dataset is saved as a JSON manifest file. 
Example usage: - python scripts/t5tts/dpo/create_text_contextpairs.py \ - --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \ - --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \ - --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \ - --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \ - --text_contexts /Data/DPOPairsInputData/text_context_list.txt \ - --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json + python scripts/magpietts/dpo/create_text_contextpairs.py \ + --challenging_texts /Data/DPOPairsInputDatav2/challenging_with_short.txt \ + --regular_texts_for_audiocontext /Data/DPOPairsInputDatav2/regular_texts_for_audiocontext.txt \ + --regular_texts_for_textcontext /Data/DPOPairsInputDatav2/regular_texts_for_textcontext.txt \ + --audio_contexts /Data/DPOPairsInputDatav2/audio_context_list.json \ + --text_contexts /Data/DPOPairsInputDatav2/text_context_list_with_audio.txt \ + --output_manifest /Data/DPOPairsInputDatav2/grpo_train_with_short.json \ + --n_audio_contexts_per_challenging_text 2 \ + --n_text_contexts_per_challenging_text 2 \ + --n_audio_contexts_per_regular_text 1 \ + --n_text_contexts_per_regular_text 1 \ + --nsamples_perpair 1 ; """ parser = argparse.ArgumentParser(description='Create text-context pairs for DPO') parser.add_argument("--challenging_texts", type=str, help="Text file containing challenging texts") @@ -83,8 +88,6 @@ def main(): text_contexts = [text for text in text_contexts if text.strip() != ''] all_records = [] - dummy_audio_filepath = audio_contexts[0]['context_audio_filepath'] - dummy_target_audio_codes_path = audio_contexts[0].get('context_audio_codes_path', None) for challenging_text in challenging_texts: for _ in range(args.n_audio_contexts_per_challenging_text): audio_context = random.choice(audio_contexts) @@ -93,9 +96,7 @@ def main(): for _ in range(args.n_text_contexts_per_challenging_text): text_context = random.choice(text_contexts) - record = create_text_context_record( - challenging_text, text_context, dummy_audio_filepath, 'challenging', dummy_target_audio_codes_path - ) + record = create_text_context_record(challenging_text, text_context, 'challenging') all_records.append(record) for regular_text in regular_texts_for_audiocontext: @@ -107,9 +108,7 @@ def main(): for regular_text in regular_texts_for_textcontext: for _ in range(args.n_text_contexts_per_regular_text): text_context = random.choice(text_contexts) - record = create_text_context_record( - regular_text, text_context, dummy_audio_filepath, 'regular', dummy_target_audio_codes_path - ) + record = create_text_context_record(regular_text, text_context, 'regular') all_records.append(record) random.shuffle(all_records) @@ -151,29 +150,29 @@ def create_audio_context_record(text, audio_context, record_type): return record -def create_text_context_record(text, text_context, dummy_audio_filepath, record_type, target_audio_codes_path=None): +def create_text_context_record(text, text_context, record_type): """ Creates a record for a text-context pair with text context. Args: text (str): The main text content. text_context (str): The associated text context. - dummy_audio_filepath (str): A placeholder audio file path. record_type (str): Type of record ('challenging' or 'regular'). - target_audio_codes_path (str, optional): Optional target audio codes path. Returns: dict: A dictionary representing the text context record. 
""" + if text_context.endswith("\n"): + text_context = text_context[:-1] record = { 'text': text, 'duration': 6.0, # Does not matter, avoids filtering out in DPO, - 'audio_filepath': dummy_audio_filepath, - 'context_text': text_context, + 'audio_filepath': text_context.split(",")[1], + 'context_text': text_context.split(",")[0], 'record_type': record_type, # challenging or regular } - if target_audio_codes_path is not None: - record['target_audio_codes_path'] = target_audio_codes_path + if text_context.split(",")[-1].endswith(".pt"): + record['target_audio_codes_path'] = text_context.split(",")[-1] return record diff --git a/scripts/magpietts/evalset_config.py b/scripts/magpietts/evalset_config.py new file mode 100644 index 000000000000..2380e8372ca9 --- /dev/null +++ b/scripts/magpietts/evalset_config.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Used as a datafile for infer_and_evaluate.py +""" +dataset_meta_info = { + 'riva_hard_digits': { + 'manifest_path': '/Data/evaluation_manifests/hard-digits-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', + }, + 'riva_hard_letters': { + 'manifest_path': '/Data/evaluation_manifests/hard-letters-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', + }, + 'riva_hard_money': { + 'manifest_path': '/Data/evaluation_manifests/hard-money-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', + }, + 'riva_hard_short': { + 'manifest_path': '/Data/evaluation_manifests/hard-short-path-corrected.ndjson', + 'audio_dir': '/Data/RIVA-TTS', + 'feature_dir': '/Data/RIVA-TTS', + }, + 'vctk': { + 'manifest_path': '/Data/evaluation_manifests/smallvctk__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM_simplet5_withcontextaudiopaths_silence_trimmed.json', + 'audio_dir': '/Data/VCTK-Corpus-0.92', + 'feature_dir': '/Data/VCTK-Corpus-0.92', + }, + 'libritts_seen': { + 'manifest_path': '/Data/evaluation_manifests/LibriTTS_seen_evalset_from_testclean_v2.json', + 'audio_dir': '/Data/LibriTTS', + 'feature_dir': '/Data/LibriTTS', + }, + 'libritts_test_clean': { + 'manifest_path': '/Data/evaluation_manifests/LibriTTS_test_clean_withContextAudioPaths.jsonl', + 'audio_dir': '/Data/LibriTTS', + 'feature_dir': '/Data/LibriTTS', + }, + # We need an4_val_ci just for CI tests + 'an4_val_ci': { + 'manifest_path': '/home/TestData/an4_dataset/an4_val_context_v1.json', + 'audio_dir': '/', + 'feature_dir': None, + }, +} diff --git a/scripts/magpietts/evaluate_generated_audio.py b/scripts/magpietts/evaluate_generated_audio.py new file mode 100644 index 000000000000..a1b25705741b --- /dev/null +++ b/scripts/magpietts/evaluate_generated_audio.py @@ -0,0 +1,488 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Used in infer_and_evaluate.py to obtain metrics such as ASR_WER and UTMOSV2 scores. +""" +import argparse +import json +import logging +import os +import pprint +import string +import tempfile +import time +from contextlib import contextmanager +from functools import partial + +import librosa +import numpy as np +import scripts.magpietts.evalset_config as evalset_config +import soundfile as sf +import torch +from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector, WhisperForConditionalGeneration, WhisperProcessor + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.metrics.wer import word_error_rate_detail +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance +from nemo.collections.tts.modules.utmosv2 import UTMOSv2Calculator + + +def find_generated_files(audio_dir, prefix, extension): + file_list = [] + for f in os.listdir(audio_dir): + if prefix in f and f.endswith(extension): + audio_number = int(f.split("_")[-1].split(extension)[0]) + file_list.append((audio_number, os.path.join(audio_dir, f))) + file_list.sort() + file_list = [t[1] for t in file_list] + return file_list + + +def find_generated_audio_files(audio_dir): + return find_generated_files(audio_dir=audio_dir, prefix="predicted_audio", extension=".wav") + + +def find_generated_codec_files(audio_dir): + return find_generated_files(audio_dir=audio_dir, prefix="predicted_codes", extension=".pt") + + +def get_wav_file_duration(audio_path: str) -> float: + """ + Get the duration of an WAV file in seconds. 
+ """ + # get extension of the file + extension = os.path.splitext(audio_path)[1] + if extension.lower() != ".wav": + raise ValueError(f"Audio path {audio_path} is not a WAV file") + info = sf.info(audio_path) + seconds = info.frames / info.samplerate + return seconds + + +def read_manifest(manifest_path): + records = [] + with open(manifest_path, 'r') as f: + all_lines = f.readlines() + for line in all_lines: + line = line.strip() + records.append(json.loads(line)) + return records + + +def process_text(input_text): + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + return single_space_text + + +def transcribe_with_whisper(whisper_model, whisper_processor, audio_path, language, device): + speech_array, sampling_rate = librosa.load(audio_path, sr=16000) + # Set the language task (optional, improves performance for specific languages) + forced_decoder_ids = ( + whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None + ) + inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features + inputs = inputs.to(device) + # Generate transcription + with torch.inference_mode(): + predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) + + # Decode transcription + transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) + result = transcription[0] + return result + + +def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_seconds: float) -> np.ndarray: + """ + Pad audio to make it at least `min_seconds` long by adding silence at the end if needed. + """ + if audio_np.ndim != 1: + raise ValueError("Audio array must be 1D") + + n_samples = len(audio_np) + min_samples = round(min_seconds * sampling_rate) + + if n_samples < min_samples: + print(f"Padding audio from {n_samples/sampling_rate} seconds to {min_samples/sampling_rate} seconds") + padding_needed = min_samples - n_samples + audio_np = np.pad(audio_np, (0, padding_needed), mode='constant', constant_values=0) + return audio_np + + +@contextmanager +def nemo_log_level(level): + """ + A context manager that temporarily sets the logging level for the NeMo logger + and restores the original level when the context manager is exited. + + Args: + level (int): The logging level to set. 
+ """ + logger = logging.getLogger("nemo_logger") + original_level = logger.level + logger.setLevel(level) + try: + yield + finally: + # restore the original level when the context manager is exited (even if an exception was raised) + logger.setLevel(original_level) + + +def extract_embedding(model, extractor, audio_path, device, sv_model_type): + speech_array, sampling_rate = librosa.load(audio_path, sr=16000) + # pad to 0.5 seconds as the extractor may not be able to handle very short signals + speech_array = pad_audio_to_min_length(speech_array, int(sampling_rate), min_seconds=0.5) + if sv_model_type == "wavlm": + inputs = extractor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device) + with torch.inference_mode(): + embeddings = model(inputs).embeddings + else: # Titanet + with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file: + # the embedding model doesn't accept NumPy arrays, so we write to a temporary file + sf.write(temp_file.name, speech_array, samplerate=16000) + with torch.inference_mode(): + embeddings = model.get_embedding(temp_file.name).squeeze() + + return embeddings.squeeze() + + +def compute_utmosv2_scores(audio_dir, device): + print(f"\nComputing UTMOSv2 scores for files in {audio_dir}...") + start_time = time.time() + utmosv2_calculator = UTMOSv2Calculator(device=device) + utmosv2_scores = utmosv2_calculator.process_directory(audio_dir) + # convert to to a dictionary indexed by file path + utmosv2_scores_dict = {os.path.normpath(item['file_path']): item['predicted_mos'] for item in utmosv2_scores} + end_time = time.time() + print(f"UTMOSv2 scores computed for {len(utmosv2_scores)} files in {end_time - start_time:.2f} seconds\n") + return utmosv2_scores_dict + + +def evaluate( + manifest_path, + audio_dir, + generated_audio_dir, + language="en", + sv_model_type="titanet", + asr_model_name="stt_en_conformer_transducer_large", + codecmodel_path=None, + with_utmosv2=True, +): + audio_file_lists = find_generated_audio_files(generated_audio_dir) + records = read_manifest(manifest_path) + assert len(audio_file_lists) == len(records) + if codecmodel_path is not None: + codes_file_lists = find_generated_codec_files(generated_audio_dir) + assert len(codes_file_lists) == len(records) + + device = "cuda" + + whisper_processor = None # Address CodeQL issue even though this varibable is only used when language != "en" + utmosv2_scores = None # Address CodeQL issue even though this varibable is only used when with_utmosv2 is true + if language == "en": + if asr_model_name.startswith("nvidia/") or asr_model_name in ["stt_en_conformer_transducer_large"]: + asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=asr_model_name) + else: + raise ValueError(f"ASR model {asr_model_name} not supported") + asr_model = asr_model.to(device) + asr_model.eval() + else: + whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + whisper_model = whisper_model.to(device) + whisper_model.eval() + + if sv_model_type == "wavlm": + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv') + speaker_verification_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv').to(device).eval() + else: + feature_extractor = None + speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_large' + ) + speaker_verification_model = 
speaker_verification_model.to(device) + speaker_verification_model.eval() + with nemo_log_level(logging.ERROR): + # The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily + print("Loading `titanet_small` model...") + speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_small' + ) + speaker_verification_model_alternate = speaker_verification_model_alternate.to(device) + speaker_verification_model_alternate.eval() + + if codecmodel_path is not None: + codec = AudioCodecModel.restore_from(codecmodel_path, strict=False) + codec = codec.to(device) + codec.eval() + # The FCD metric measures a distance between generated and real codec frames. The distance + # is measured in the codec's embedding space. `codec_feature_dim` is the size of the codec's embedding vector. + # For example, for a group-FSQ codec with 8 codebooks with 4 values in each codebook, the embedding dimension is 8 x 4 = 32. + codec_feature_dim = codec.vector_quantizer.codebook_dim + fcd_metric = FrechetCodecDistance(codec=codec, feature_dim=codec_feature_dim).to(device) + else: + print("No codec model provided, skipping FCD metric") + fcd_metric = None + + if with_utmosv2: + utmosv2_scores = compute_utmosv2_scores(generated_audio_dir, device) + filewise_metrics = [] + pred_texts = [] + gt_texts = [] + gt_audio_texts = [] + total_generated_audio_seconds = 0.0 + for ridx, record in enumerate(records): + gt_audio_filepath = record['audio_filepath'] + context_audio_filepath = record.get('context_audio_filepath', None) + if audio_dir is not None: + gt_audio_filepath = os.path.join(audio_dir, gt_audio_filepath) + if context_audio_filepath is not None: + context_audio_filepath = os.path.join(audio_dir, context_audio_filepath) + # Update the FCD metric for *real* codes + if fcd_metric is not None: + fcd_metric.update_from_audio_file(gt_audio_filepath, True) + + pred_audio_filepath = audio_file_lists[ridx] + if fcd_metric is not None: + pred_codes_filepath = codes_file_lists[ridx] + + if with_utmosv2: + utmosv2_score = utmosv2_scores[os.path.normpath(pred_audio_filepath)] + else: + utmosv2_score = 0.0 + + try: + if language == "en": + with torch.inference_mode(): + pred_text = asr_model.transcribe([pred_audio_filepath])[0].text + pred_text = process_text(pred_text) + gt_audio_text = asr_model.transcribe([gt_audio_filepath])[0].text + gt_audio_text = process_text(gt_audio_text) + else: + pred_text = transcribe_with_whisper( + whisper_model, whisper_processor, pred_audio_filepath, language, device + ) + pred_text = process_text(pred_text) + gt_audio_text = transcribe_with_whisper( + whisper_model, whisper_processor, gt_audio_filepath, language, device + ) + gt_audio_text = process_text(gt_audio_text) + except Exception as e: + print("Error during ASR: {}".format(e)) + pred_text = "" + gt_audio_text = "" + + if "original_text" in record: + gt_text = process_text(record['original_text']) + elif 'normalized_text' in record: + gt_text = process_text(record['normalized_text']) + else: + gt_text = process_text(record['text']) + + detailed_cer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=True) + detailed_wer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=False) + + print("{} GT Text:".format(ridx), gt_text) + print("{} Pr Text:".format(ridx), pred_text) + # Format cer and wer to 2 decimal places + print("CER:", "{:.4f} | WER: {:.4f}".format(detailed_cer[0], 
detailed_wer[0]))
+
+        pred_texts.append(pred_text)
+        gt_texts.append(gt_text)
+        gt_audio_texts.append(gt_audio_text)
+
+        # update FCD metric
+        if fcd_metric is not None:
+            predicted_codes = torch.load(pred_codes_filepath).unsqueeze(0)  # B, C, T
+            predicted_codes_lens = torch.tensor([predicted_codes.size(-1)], dtype=torch.int, device=device)
+            fcd_metric.update(predicted_codes, predicted_codes_lens, False)
+
+        pred_context_ssim = 0.0
+        gt_context_ssim = 0.0
+        # Also initialize the alternate-model context scores so they are always defined in the
+        # metrics dict below, even for records without a context_audio_filepath.
+        pred_context_ssim_alternate = 0.0
+        gt_context_ssim_alternate = 0.0
+        with torch.inference_mode():
+            extract_embedding_fn = partial(
+                extract_embedding,
+                model=speaker_verification_model,
+                extractor=feature_extractor,
+                device=device,
+                sv_model_type=sv_model_type,
+            )
+            extract_embedding_fn_alternate = partial(
+                extract_embedding,
+                model=speaker_verification_model_alternate,
+                extractor=feature_extractor,
+                device=device,
+                sv_model_type=sv_model_type,
+            )
+
+            # Ground truth vs. predicted
+            gt_speaker_embedding = extract_embedding_fn(audio_path=gt_audio_filepath)
+            pred_speaker_embedding = extract_embedding_fn(audio_path=pred_audio_filepath)
+            pred_gt_ssim = torch.nn.functional.cosine_similarity(
+                gt_speaker_embedding, pred_speaker_embedding, dim=0
+            ).item()
+
+            # Ground truth vs. predicted (alternate model)
+            gt_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=gt_audio_filepath)
+            pred_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=pred_audio_filepath)
+            pred_gt_ssim_alternate = torch.nn.functional.cosine_similarity(
+                gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0
+            ).item()
+
+            if context_audio_filepath is not None:
+                context_speaker_embedding = extract_embedding_fn(audio_path=context_audio_filepath)
+                context_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=context_audio_filepath)
+
+                # Predicted vs. context
+                pred_context_ssim = torch.nn.functional.cosine_similarity(
+                    pred_speaker_embedding, context_speaker_embedding, dim=0
+                ).item()
+                # Ground truth vs. context
+                gt_context_ssim = torch.nn.functional.cosine_similarity(
+                    gt_speaker_embedding, context_speaker_embedding, dim=0
+                ).item()
+
+                # Predicted vs. context (alternate model)
+                pred_context_ssim_alternate = torch.nn.functional.cosine_similarity(
+                    pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0
+                ).item()
+                # Ground truth vs.
context (alternate model) + gt_context_ssim_alternate = torch.nn.functional.cosine_similarity( + gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0 + ).item() + total_generated_audio_seconds += get_wav_file_duration(pred_audio_filepath) + + filewise_metrics.append( + { + 'gt_text': gt_text, + 'pred_text': pred_text, + 'gt_audio_text': gt_audio_text, + 'detailed_cer': detailed_cer, + 'detailed_wer': detailed_wer, + 'cer': detailed_cer[0], + 'wer': detailed_wer[0], + 'pred_gt_ssim': pred_gt_ssim, + 'pred_context_ssim': pred_context_ssim, + 'gt_context_ssim': gt_context_ssim, + 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, + 'pred_context_ssim_alternate': pred_context_ssim_alternate, + 'gt_context_ssim_alternate': gt_context_ssim_alternate, + 'gt_audio_filepath': gt_audio_filepath, + 'pred_audio_filepath': pred_audio_filepath, + 'context_audio_filepath': context_audio_filepath, + 'utmosv2': utmosv2_score, + } + ) + + filewise_metrics_keys_to_save = [ + 'cer', + 'wer', + 'pred_context_ssim', + 'pred_text', + 'gt_text', + 'gt_audio_filepath', + 'pred_audio_filepath', + 'context_audio_filepath', + ] + filtered_filewise_metrics = [] + for m in filewise_metrics: + filtered_filewise_metrics.append({k: m[k] for k in filewise_metrics_keys_to_save}) + + # Sort filewise metrics by cer in reverse + filewise_metrics.sort(key=lambda x: x['cer'], reverse=True) + + # compute frechet distance for the whole test set + if fcd_metric is not None: + fcd = fcd_metric.compute().cpu().item() + fcd_metric.reset() + else: + fcd = 0.0 + + avg_metrics = {} + avg_metrics['cer_filewise_avg'] = sum([m['detailed_cer'][0] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['wer_filewise_avg'] = sum([m['detailed_wer'][0] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['cer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=True)[0] + avg_metrics['wer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[ + 0 + ] + avg_metrics['ssim_pred_gt_avg'] = sum([m['pred_gt_ssim'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_pred_context_avg'] = sum([m['pred_context_ssim'] for m in filewise_metrics]) / len( + filewise_metrics + ) + avg_metrics['ssim_gt_context_avg'] = sum([m['gt_context_ssim'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics['ssim_pred_gt_avg_alternate'] = sum([m['pred_gt_ssim_alternate'] for m in filewise_metrics]) / len( + filewise_metrics + ) + avg_metrics['ssim_pred_context_avg_alternate'] = sum( + [m['pred_context_ssim_alternate'] for m in filewise_metrics] + ) / len(filewise_metrics) + avg_metrics['ssim_gt_context_avg_alternate'] = sum( + [m['gt_context_ssim_alternate'] for m in filewise_metrics] + ) / len(filewise_metrics) + avg_metrics["cer_gt_audio_cumulative"] = word_error_rate_detail( + hypotheses=gt_audio_texts, references=gt_texts, use_cer=True + )[0] + avg_metrics["wer_gt_audio_cumulative"] = word_error_rate_detail( + hypotheses=gt_audio_texts, references=gt_texts, use_cer=False + )[0] + avg_metrics["frechet_codec_distance"] = fcd + avg_metrics["utmosv2_avg"] = sum([m['utmosv2'] for m in filewise_metrics]) / len(filewise_metrics) + avg_metrics["total_gen_audio_seconds"] = total_generated_audio_seconds + pprint.pprint(avg_metrics) + + return avg_metrics, filewise_metrics + + +def main(): + # audio_dir="/datap/misc/Datasets/riva" \ + parser = argparse.ArgumentParser(description='Evaluate Generated Audio') + 
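+    # Illustrative invocation (paths are placeholders); --evalset fills in manifest_path and audio_dir
+    # from scripts/magpietts/evalset_config.py, e.g.
+    #   python scripts/magpietts/evaluate_generated_audio.py --generated_audio_dir <generated_dir> --evalset an4_val_ci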
parser.add_argument('--manifest_path', type=str, default=None) + parser.add_argument('--audio_dir', type=str, default=None) + parser.add_argument('--generated_audio_dir', type=str, default=None) + parser.add_argument('--whisper_language', type=str, default="en") + parser.add_argument('--evalset', type=str, default=None) + args = parser.parse_args() + + if args.evalset is not None: + dataset_meta_info = evalset_config.dataset_meta_info + assert args.evalset in dataset_meta_info + args.manifest_path = dataset_meta_info[args.evalset]['manifest_path'] + args.audio_dir = dataset_meta_info[args.evalset]['audio_dir'] + + evaluate( + args.manifest_path, + args.audio_dir, + args.generated_audio_dir, + args.whisper_language, + sv_model_type="wavlm", + asr_model_name="nvidia/parakeet-ctc-0.6b", + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py b/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py new file mode 100644 index 000000000000..56038108ad70 --- /dev/null +++ b/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py @@ -0,0 +1,912 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script extends the Lhotse shards with audio codec codes. + +Example of input shards: + $ tree ${CUTS_DIR} + ${CUTS_DIR}/ + cuts.000000.jsonl.gz + cuts.000001.jsonl.gz + ... + + $ tree ${TARGET_AUDIO_DIR} + ${TARGET_AUDIO_DIR}/ + recording.000000.tar + recording.000001.tar + ... + + $ tree ${CONTEXT_AUDIO_DIR} + ${CONTEXT_AUDIO_DIR}/ + recording.000000.tar + recording.000001.tar + ... + +Example usage: + export WANDB_API_KEY=${WANDB} + python -u ${CODE_DIR}/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py \ + --cuts-dir ${CUTS_DIR} \ + --target-audio-dir ${TARGET_AUDIO_DIR} \ + --context-audio-dir ${CONTEXT_AUDIO_DIR} \ + --output-dir ${RESULTS} \ + --codec-model-name ${CODEC_MODEL_NAME} \ + --codec-model-path ${CODEC_MODEL_PATH} \ + --codec-frame-rate ${CODEC_FRAME_RATE} \ + --devices ${DEVICES} \ + --num-nodes ${NUM_NODES} \ + --batch-size ${BATCH_SIZE} \ + --buffer-size ${BUFFER_SIZE} \ + --wandb-entity ${WANDB_ENTITY} \ + --wandb-project ${WANDB_PROJECT} \ + --wandb-name ${WANDB_NAME} \ + --log-level "DEBUG" \ + 2>&1 | tee ${LOG}/${WANDB_NAME}.stdout + +Expected output: + $ tree ${RESULTS} + ${RESULTS}/ + 21fpsCausalDecoder/ + target_codes/ + codes.000000.tar + codes.000001.tar + ... + context_codes/ + codes.000000.tar + codes.000001.tar + ... 
+""" + +import argparse +import glob +import logging +import os +import re +import threading +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import lightning.pytorch as pl +import torch +import wandb +from lhotse import CutSet +from lhotse.array import Array, TemporalArray +from lhotse.dataset import IterableDatasetWrapper, SimpleCutSampler +from lhotse.shar.writers.array import ArrayTarWriter +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import BasePredictionWriter +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.strategies import DDPStrategy +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from nemo.collections.tts.models import AudioCodecModel + + +def compute_effective_audio_length(original_audio_tensor: torch.Tensor, samples_per_frame: int) -> int: + """Computes the effective length of an audio tensor, padded to be a multiple of samples_per_frame.""" + original_len = original_audio_tensor.shape[0] + effective_len = original_len + if samples_per_frame > 0: + effective_len = ((original_len + samples_per_frame - 1) // samples_per_frame) * samples_per_frame + return effective_len + + +def collate_audio_vectors( + audio_list: List[torch.Tensor], audio_lens_list: List[int], padding_value: Union[float, int] +) -> torch.Tensor: + """ + Collate a list of audio vectors into a single tensor, handling padding for variable lengths. + Returns a padded tensor. + """ + assert all(len(t.shape) == 1 for t in audio_list), "Expected only 1-D input tensors." + assert len(audio_list) == len(audio_lens_list), "Expected the same number of audio vectors and lengths." + + # Create a padded tensor with the maximum audio length from audio_lens_list, where its max length could be longer than + # max length of `audio_list``. For example, `audio_lens_list` could be a multiple of the codec model samples per frame. + result = audio_list[0].new_ones(len(audio_lens_list), max(audio_lens_list)) * padding_value + for i, t in enumerate(audio_list): + result[i, : t.shape[0]] = t + return result + + +class AudioPairLhotseDataset(Dataset): + """ + A Lhotse Dataset that processes a batch of MonoCuts (received as a CutSet) + containing target and context audio. + Designed to be used with a Lhotse sampler yielding CutSet batches. + Handles loading audio and collating the batch within __getitem__. + """ + + def __init__(self, target_sample_rate: int, codec_model_samples_per_frame: int): + self.target_sample_rate = target_sample_rate + self.codec_model_samples_per_frame = codec_model_samples_per_frame + + def __getitem__(self, cuts: CutSet) -> Optional[Dict[str, Any]]: + original_target_audios_list = [] + effective_target_lengths_list = [] + original_context_audios_list = [] + effective_context_lengths_list = [] + target_cut_ids_list = [] + shard_indices_list = [] + + for cut in cuts: + if not cut.has_custom("shard_origin"): + err_msg = f"Cut {cut} is missing required key 'shard_origin'." + logging.error(err_msg) + raise ValueError(err_msg) + if not cut.has_custom("context_recording"): + err_msg = f"Cut {cut} is missing required key 'context_recording'." 
+ logging.error(err_msg) + raise ValueError(err_msg) + + # Parse shard index from the custom field, handling potential errors + origin_path = cut.custom["shard_origin"] + match = re.search(r"cuts\.(\d+)\.jsonl\.gz$", origin_path) + if match is None: + raise ValueError(f"Could not parse shard index from shard_origin: {origin_path}") + shard_idx_origin = int(match.group(1)) + + # audio shape: (num_channels (1), num_samples) -> (num_samples) + # resample to target sample rate + target_audio = torch.from_numpy(cut.recording.resample(self.target_sample_rate).load_audio().squeeze(0)) + context_audio = torch.from_numpy( + cut.context_recording.resample(self.target_sample_rate).load_audio().squeeze(0) + ) + original_target_audios_list.append(target_audio) + original_context_audios_list.append(context_audio) + + eff_target_len = compute_effective_audio_length(target_audio, self.codec_model_samples_per_frame) + effective_target_lengths_list.append(eff_target_len) + + eff_context_len = compute_effective_audio_length(context_audio, self.codec_model_samples_per_frame) + effective_context_lengths_list.append(eff_context_len) + + target_cut_ids_list.append(cut.id) + shard_indices_list.append(shard_idx_origin) + + # Ensure lists are not empty before calling collate_audio_vectors. + if not original_target_audios_list: + err_msg = "AudioPairLhotseDataset.__getitem__ processed an empty CutSet or failed to load any audio data, resulting in an empty audio list." + logging.error(err_msg) + raise ValueError(err_msg) + + target_audio_padded_batch = collate_audio_vectors( + original_target_audios_list, effective_target_lengths_list, padding_value=0.0 + ) + context_audio_padded_batch = collate_audio_vectors( + original_context_audios_list, effective_context_lengths_list, padding_value=0.0 + ) + + # TODO: is it really necessary to convert lengths to torch.int64? currently applying torch.int32. + target_audio_lens_collated = torch.IntTensor(effective_target_lengths_list) + context_audio_lens_collated = torch.IntTensor(effective_context_lengths_list) + + return { + "target_audios": target_audio_padded_batch, + "target_audio_lens": target_audio_lens_collated, + "context_audios": context_audio_padded_batch, + "context_audio_lens": context_audio_lens_collated, + "target_cut_id": target_cut_ids_list, + "shard_idx_origin": shard_indices_list, + } + + +class CodecExtractor(pl.LightningModule): + """ + LightningModule to extract codec codes. Manages DataLoader creation and + distribution via predict_dataloader hook. + """ + + def __init__( + self, + model_path: str, + cuts_dir: str, + target_audio_dir: str, + context_audio_dir: str, + batch_size: int, + ): + super().__init__() + self.model_path = model_path + self.cuts_dir = Path(cuts_dir) + self.target_audio_dir = Path(target_audio_dir) + self.context_audio_dir = Path(context_audio_dir) + self.batch_size = batch_size + + logging.info(f"Initializing `AudioPairLhotseDataset` with model path: {self.model_path}") + # load the model. mapping to cpu is to avoid GPU mem spikes when initializing the model + self.codec_model = AudioCodecModel.restore_from(restore_path=self.model_path, map_location='cpu', strict=False) + self.codec_model.eval() + logging.info("Codec model loaded.") + + # Placeholder for the rank-specific list of dataloaders + self._rank_dataloaders: Optional[List[DataLoader]] = None + + def predict_dataloader(self) -> List[DataLoader]: + """ + Creates and returns the list of DataLoaders assigned to the current rank. + Caches the result to avoid redundant creation. 
+ + This function is called by the Trainer to get the dataloaders for the current rank. This happens after + intializing `model.predict()` but before any actual prediction steps (ie. calls to `model.predict_step()`) are executed. + """ + # Return cached dataloaders if already created for this rank + if self._rank_dataloaders is not None: + return self._rank_dataloaders + + # Determine rank and world size + try: + # Prefer trainer attributes if available + current_global_rank = self.global_rank + world_size = self.trainer.world_size + except AttributeError: + # Fallback to torch.distributed if trainer attributes aren't set yet + current_global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + + logging.info(f"[Rank {current_global_rank}/{world_size}] Creating assigned subset of dataloaders...") + + # Find all shard files globally + cuts_shard_pattern = str(self.cuts_dir / "cuts.*.jsonl.gz") + all_cuts_shard_paths = sorted(glob.glob(cuts_shard_pattern)) + + if not all_cuts_shard_paths: + msg = f"[Rank {current_global_rank}/{world_size}] No input cut shards found matching pattern: {cuts_shard_pattern}. Cannot proceed." + logging.error(msg) + raise FileNotFoundError(msg) + + num_total_shards = len(all_cuts_shard_paths) + + # Verify shard indices are contiguous and start from 0 based on filenames (globally) + first_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[0]).group(1) + last_idx_str = re.search(r"cuts\.(\d+)\.jsonl\.gz$", all_cuts_shard_paths[-1]).group(1) + first_idx = int(first_idx_str) + last_idx = int(last_idx_str) + expected_last_idx = num_total_shards - 1 + if first_idx != 0: + raise ValueError(f"Expected first shard index to be 0, but found {first_idx} in {all_cuts_shard_paths[0]}") + if last_idx != expected_last_idx: + raise ValueError( + f"Expected last shard index to be {expected_last_idx}, but found {last_idx} in {all_cuts_shard_paths[-1]}" + ) + logging.info( + f"[Rank {current_global_rank}/{world_size}] Verified {num_total_shards} total shard files globally, with indices from {first_idx} to {last_idx}." + ) + + # Calculate the slice of original shard indices assigned to this rank + is_distributed = world_size > 1 + assigned_shard_indices_for_rank = [] + + if num_total_shards > 0: + if not is_distributed: + assigned_shard_indices_for_rank = list(range(num_total_shards)) + logging.info( + f"[Rank {current_global_rank}/{world_size}] Non-distributed mode. Will process all {num_total_shards} shards." 
+ ) + else: + num_per_rank_base = num_total_shards // world_size + num_with_extra = num_total_shards % world_size + + if current_global_rank < num_with_extra: + start_shard_offset = current_global_rank * (num_per_rank_base + 1) + num_shards_for_rank = num_per_rank_base + 1 + else: + # Offset by the shards handled by ranks with an extra one + start_shard_offset = num_with_extra + current_global_rank * num_per_rank_base + num_shards_for_rank = num_per_rank_base + + end_shard_offset = start_shard_offset + num_shards_for_rank + assigned_shard_indices_for_rank = list(range(start_shard_offset, end_shard_offset)) + + logging.info( + f"[Rank {current_global_rank}/{world_size}] Assigned original shard indices " + f"{start_shard_offset} through {end_shard_offset -1} " + f"({len(assigned_shard_indices_for_rank)} shards)" + ) + + if not assigned_shard_indices_for_rank: + logging.info( + f"[Rank {current_global_rank}/{world_size}] No shards assigned to this rank. Returning empty dataloader list. This usually happens when the number of shards is less than the number of ranks." + ) + self._rank_dataloaders = [] + return [] + + # Create DataLoaders only for the shards assigned to this rank + dataloaders_for_rank = [] + for original_shard_idx in tqdm( + assigned_shard_indices_for_rank, + total=len(assigned_shard_indices_for_rank), + desc=f">>> [Rank {current_global_rank}/{world_size}] Creating DataLoaders for its assigned shards", + ): + logging.debug(f"[Rank {current_global_rank}] Processing original shard {original_shard_idx}...") + fields = { + "cuts": [str(self.cuts_dir / f"cuts.{original_shard_idx:06d}.jsonl.gz")], + "recording": [str(self.target_audio_dir / f"recording.{original_shard_idx:06d}.tar")], + "context_recording": [str(self.context_audio_dir / f"recording.{original_shard_idx:06d}.tar")], + } + # Verify if all files exist + if not all(Path(shard_filepaths[0]).is_file() for shard_filepaths in fields.values()): + err_msg = f"[Rank {current_global_rank}/{world_size}] Missing one or more files for shard {original_shard_idx}. Files: {fields}" + logging.error(err_msg) + raise FileNotFoundError(err_msg) + + try: + logging.debug( + f"[Rank {current_global_rank}] Loading CutSet for original shard {original_shard_idx}..." + ) + shard_cutset = CutSet.from_shar(fields=fields) + logging.debug(f"[Rank {current_global_rank}] Loaded CutSet for original shard {original_shard_idx}.") + except Exception as e: + logging.critical( + f"[Rank {current_global_rank}/{world_size}] CRITICAL ERROR: Failed to load CutSet from shar for original shard index {original_shard_idx}. \ + Files attempted: {fields}. 
\ + Error: {e}", + exc_info=True, + ) + raise + + logging.debug(f"[Rank {current_global_rank}] Creating Sampler for original shard {original_shard_idx}...") + # Explicitly set rank=0, world_size=1 to ensure sampler iterates the whole shard_cutset + sampler = SimpleCutSampler( + shard_cutset, max_cuts=self.batch_size, shuffle=False, drop_last=False, rank=0, world_size=1 + ) + logging.debug(f"[Rank {current_global_rank}] Creating Dataset for original shard {original_shard_idx}...") + shard_dataset = AudioPairLhotseDataset( + target_sample_rate=self.codec_model.sample_rate, + codec_model_samples_per_frame=self.codec_model.samples_per_frame, + ) + logging.debug(f"[Rank {current_global_rank}] Wrapping Dataset for original shard {original_shard_idx}...") + iterable_dataset = IterableDatasetWrapper( + dataset=shard_dataset, + sampler=sampler, + ) + logging.debug( + f"[Rank {current_global_rank}] Creating DataLoader for original shard {original_shard_idx}..." + ) + dl = DataLoader( + dataset=iterable_dataset, + batch_size=None, + num_workers=1, # Keep num_workers=1 for `IterableDatasetWrapper + SimpleCutSampler` to avoid duplicate batches. + pin_memory=True, + ) + logging.debug( + f"[Rank {current_global_rank}] Appending DataLoader for original shard {original_shard_idx}..." + ) + dataloaders_for_rank.append(dl) + logging.debug(f"[Rank {current_global_rank}] Finished processing original shard {original_shard_idx}.") + + logging.info( + f"[Rank {current_global_rank}/{world_size}] Created {len(dataloaders_for_rank)} DataLoaders for this rank." + ) + # Cache the created dataloaders for this rank + self._rank_dataloaders = dataloaders_for_rank + return self._rank_dataloaders + + def forward( + self, + target_audios: torch.Tensor, + target_audio_lens: torch.Tensor, + context_audios: torch.Tensor, + context_audio_lens: torch.Tensor, + ) -> Optional[Dict[str, torch.Tensor]]: + try: + target_audios = target_audios.to(self.device) + target_audio_lens = target_audio_lens.to(self.device) + context_audios = context_audios.to(self.device) + context_audio_lens = context_audio_lens.to(self.device) + # NOTE: we avoided directly calling `self.codec_model.encode()` because it pads audios again. 
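+            # The batches were already padded to a multiple of the codec's samples_per_frame in
+            # AudioPairLhotseDataset, so encoding and quantizing directly avoids a second round of padding.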
+ with torch.inference_mode(): + target_audios_encoded, target_audios_encoded_len = self.codec_model.audio_encoder( + audio=target_audios, audio_len=target_audio_lens + ) + target_tokens = self.codec_model.quantize( + encoded=target_audios_encoded, encoded_len=target_audios_encoded_len + ) + context_audios_encoded, context_audios_encoded_len = self.codec_model.audio_encoder( + audio=context_audios, audio_len=context_audio_lens + ) + context_tokens = self.codec_model.quantize( + encoded=context_audios_encoded, encoded_len=context_audios_encoded_len + ) + return { + "target_codes": target_tokens.to(dtype=torch.int16, device="cpu"), + "target_codes_lengths": target_audios_encoded_len.to(device="cpu"), + "context_codes": context_tokens.to(dtype=torch.int16, device="cpu"), + "context_codes_lengths": context_audios_encoded_len.to(device="cpu"), + } + except Exception as e: + logging.error( + f"[Rank {self.global_rank}/{self.world_size}] Error during batched codec encoding: {e}", exc_info=True + ) + raise e + + def predict_step( + self, batch: Dict[str, Any], batch_idx: int, dataloader_idx: int = 0 + ) -> Optional[List[Dict[str, Any]]]: + codes_dict = self( + target_audios=batch["target_audios"], + target_audio_lens=batch["target_audio_lens"], + context_audios=batch["context_audios"], + context_audio_lens=batch["context_audio_lens"], + ) + + target_codes_batch = codes_dict["target_codes"] + target_codes_lens = codes_dict["target_codes_lengths"] + context_codes_batch = codes_dict["context_codes"] + context_codes_lens = codes_dict["context_codes_lengths"] + + target_cut_ids = batch["target_cut_id"] + shard_indices_in_batch = batch["shard_idx_origin"] + + # The shard_indices list should ideally contain the *same* original index + # for all items in a batch, because each DataLoader loads from only one shard. + results = [] + batch_size = batch["target_audios"].shape[0] + original_shard_idx = shard_indices_in_batch[0] + if not all(idx == original_shard_idx for idx in shard_indices_in_batch): + raise ValueError( + f"Inconsistent shard indices within batch! Batch Index: {batch_idx}, Dataloader Index: {dataloader_idx}. Indices: {shard_indices_in_batch}." + ) + + if len(target_cut_ids) != batch_size or target_codes_batch.shape[0] != batch_size: + raise ValueError( + f"Batch size mismatch after inference! Input IDs: {len(target_cut_ids)}, " + f"Input Audio Batch: {batch_size}, Output Codes Batch: {target_codes_batch.shape[0]}. " + f"Batch Index: {batch_idx}, Dataloader Index: {dataloader_idx}" + ) + + for target_cut_id, target_codes, context_codes, target_codes_len, context_codes_len in zip( + target_cut_ids, target_codes_batch, context_codes_batch, target_codes_lens, context_codes_lens + ): + results.append( + { + "target_cut_id": target_cut_id, + "shard_idx": original_shard_idx, + "target_codes": target_codes[:, :target_codes_len], + "context_codes": context_codes[:, :context_codes_len], + } + ) + + return results + + +class CodecPredictionWriter(BasePredictionWriter): + """ + Writes codec predictions (target and context codes) to ArrayTarWriter shards asynchronously. + Uses a ThreadPoolExecutor with a single worker to serialize writes and closing operations per shard, + allowing potential overlap between prediction computation and I/O while closing writers early. 
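+
+    Output shards are written to <output_dir>/<codec_model_name>/target_codes/codes.NNNNNN.tar and
+    <output_dir>/<codec_model_name>/context_codes/codes.NNNNNN.tar, where NNNNNN mirrors the input shard index.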
+ """ + + def __init__( + self, + output_dir: str, + codec_model_name: str, + codec_frame_rate: float, + ): + super().__init__(write_interval="batch") + self.output_dir_base = Path(output_dir) + self.codec_model_name = codec_model_name + self.codec_frame_rate = codec_frame_rate + self.rank: int = -1 + self.world_size: int = -1 + self.target_writers: Dict[int, ArrayTarWriter] = {} + self.context_writers: Dict[int, ArrayTarWriter] = {} + self.target_codes_dir: Optional[Path] = None + self.context_codes_dir: Optional[Path] = None + + # Attributes for asynchronous writing and closing + self.writer_lock: Optional[threading.Lock] = None + self.bg_worker_thread: Optional[ThreadPoolExecutor] = None + self.futures_per_shard: Optional[Dict[int, List[Future]]] = None + self.closer_futures: Optional[List[Future]] = None # Futures for the _wait_and_close_worker tasks + self.last_processed_shard_idx: int = -1 + + def setup(self, trainer: Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: + self.rank = trainer.global_rank + self.world_size = trainer.world_size + logging.info( + f"[Rank {self.rank}/{self.world_size}] Setting up CodecPredictionWriter for async writing with early close." + ) + + # Initialize async components + self.writer_lock = threading.Lock() + # Single worker ensures sequential execution of writes AND closes + self.bg_worker_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f'CodecWriterRank{self.rank}') + self.futures_per_shard = defaultdict(list) + self.closer_futures = [] + self.last_processed_shard_idx = -1 + + # Create directories + self.target_codes_dir = self.output_dir_base / self.codec_model_name / "target_codes" + self.context_codes_dir = self.output_dir_base / self.codec_model_name / "context_codes" + if self.rank == 0: + self.target_codes_dir.mkdir(parents=True, exist_ok=True) + self.context_codes_dir.mkdir(parents=True, exist_ok=True) + if trainer.world_size > 1: + torch.distributed.barrier() + logging.info(f"[Rank {self.rank}/{self.world_size}] Setup complete. Writers will be created on demand.") + + def _get_or_create_writer( + self, writer_dict: Dict[int, ArrayTarWriter], shard_idx: int, base_dir: Path + ) -> ArrayTarWriter: + # Lock needed as this might be called from main thread while closer task modifies dicts + with self.writer_lock: + if shard_idx not in writer_dict: + output_filename = str(base_dir / f"codes.{shard_idx:06d}.tar") + logging.debug( + f"[Rank {self.rank}/{self.world_size}] Creating writer for shard {shard_idx} (Thread-safe check): {output_filename}" + ) + try: + writer = ArrayTarWriter(pattern=output_filename, shard_size=None, compression="numpy") + writer.__enter__() + writer_dict[shard_idx] = writer + logging.info(f"[Rank {self.rank}/{self.world_size}] Created writer for shard {shard_idx}") + except Exception as e: + msg = f"[Rank {self.rank}/{self.world_size}] Failed to create writer for shard {shard_idx} (file: {output_filename}): {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + # Return writer even if it might be closed soon by a background task + # The background task handles the actual closing. 
+ return writer_dict[shard_idx] + + def _write_worker( + self, + target_cut_id: str, + shard_idx: int, + target_codes: torch.Tensor, + context_codes: torch.Tensor, + target_writer: ArrayTarWriter, + context_writer: ArrayTarWriter, + ): + """Worker function executed by the background thread to write a single item.""" + # Assuming target_writer and context_writer are valid when this task starts + try: + target_codes_array_manifest = TemporalArray( + array=Array(storage_type="shar", storage_path="", storage_key="", shape=list(target_codes.shape)), + temporal_dim=-1, + frame_shift=1 / self.codec_frame_rate, + start=0, + ) + context_codes_array_manifest = TemporalArray( + array=Array(storage_type="shar", storage_path="", storage_key="", shape=list(context_codes.shape)), + temporal_dim=-1, + frame_shift=1 / self.codec_frame_rate, + start=0, + ) + target_writer.write(key=target_cut_id, value=target_codes.numpy(), manifest=target_codes_array_manifest) + context_writer.write(key=target_cut_id, value=context_codes.numpy(), manifest=context_codes_array_manifest) + logging.debug(f"[Worker Rank {self.rank}] Wrote item {target_cut_id} for shard {shard_idx}") + except Exception as e: + msg = f"[Worker Rank {self.rank}] CRITICAL I/O ERROR writing item {target_cut_id} for shard {shard_idx}: {e}. Writer might be closed prematurely?" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + def _wait_and_close_worker(self, shard_idx_to_close: int): + """Waits for all write tasks of a shard, then closes and removes its writers.""" + logging.info(f"[Worker Rank {self.rank}] Starting closure process for shard {shard_idx_to_close}") + # 1. Retrieve and remove the list of write futures for this shard + # Do this early to prevent new futures being added for this closing shard? + # No, write_on_batch_end logic prevents submission for old shards. + write_futures = self.futures_per_shard.pop(shard_idx_to_close, []) + + # 2. Wait for all write operations for this shard to complete + logging.info( + f"[Worker Rank {self.rank}] Waiting for {len(write_futures)} write tasks for shard {shard_idx_to_close}..." + ) + processed_write_futures = 0 + if write_futures: + for f in write_futures: + try: + f.result() # Wait for completion + processed_write_futures += 1 + except Exception as e: + # Write worker already logged this, but log context here + logging.error( + f"[Worker Rank {self.rank}] Exception during write future.result() for shard {shard_idx_to_close}: {e}", + exc_info=False, + ) + logging.info( + f"[Worker Rank {self.rank}] Completed {processed_write_futures}/{len(write_futures)} write tasks for shard {shard_idx_to_close}." + ) + else: + logging.warning( + f"[Worker Rank {self.rank}] No write futures found to wait for shard {shard_idx_to_close} during close." + ) + + # 3. Safely remove and close the writers + writers_closed_count = 0 + with self.writer_lock: # Protect access to the writer dictionaries + target_writer = self.target_writers.pop(shard_idx_to_close, None) + context_writer = self.context_writers.pop(shard_idx_to_close, None) + + if target_writer: + try: + target_writer.close() + logging.info(f"[Worker Rank {self.rank}] Closed target writer for shard {shard_idx_to_close}.") + writers_closed_count += 1 + except Exception as e: + logging.error( + f"[Worker Rank {self.rank}] Error closing target writer for shard {shard_idx_to_close}: {e}", + exc_info=True, + ) + else: + logging.warning( + f"[Worker Rank {self.rank}] Target writer for shard {shard_idx_to_close} not found during close." 
+ ) + + if context_writer: + try: + context_writer.close() + logging.info(f"[Worker Rank {self.rank}] Closed context writer for shard {shard_idx_to_close}.") + writers_closed_count += 1 + except Exception as e: + logging.error( + f"[Worker Rank {self.rank}] Error closing context writer for shard {shard_idx_to_close}: {e}", + exc_info=True, + ) + else: + logging.warning( + f"[Worker Rank {self.rank}] Context writer for shard {shard_idx_to_close} not found during close." + ) + + logging.info( + f"[Worker Rank {self.rank}] Finished closure process for shard {shard_idx_to_close}. Closed {writers_closed_count} writers." + ) + + def write_on_batch_end( + self, + trainer: Trainer, + pl_module: pl.LightningModule, + predictions: Optional[List[Dict[str, Any]]], + batch_indices: Optional[List[int]], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + + if not predictions: + err_msg = f"[Rank {self.rank}/{self.world_size}] Received empty predictions list for batch_idx {batch_idx}, dataloader_idx {dataloader_idx}. Skipping." + logging.error(err_msg) + raise ValueError(err_msg) + + current_shard_idx = predictions[0]["shard_idx"] + if not all(p["shard_idx"] == current_shard_idx for p in predictions): + raise ValueError( + f"[Rank {self.rank}] Inconsistent shard indices within batch! Batch Index: {batch_idx}, Dataloader Index: {dataloader_idx}." + ) + + # Check for shard change and submit closer task for the previous shard + if current_shard_idx != self.last_processed_shard_idx and self.last_processed_shard_idx != -1: + logging.info( + f"[Rank {self.rank}] Shard index changed from {self.last_processed_shard_idx} to {current_shard_idx}. " + f"Submitting closure task for shard {self.last_processed_shard_idx}." + ) + try: + closer_future = self.bg_worker_thread.submit( + self._wait_and_close_worker, self.last_processed_shard_idx + ) + self.closer_futures.append(closer_future) + except Exception as e: + msg = f"[Rank {self.rank}] Failed to submit closer task for shard {self.last_processed_shard_idx}: {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + self.last_processed_shard_idx = current_shard_idx + + # Submit write tasks for each item in the current batch + for prediction in predictions: + try: + target_cut_id = prediction["target_cut_id"] + shard_idx = prediction["shard_idx"] + target_codes = prediction["target_codes"] + context_codes = prediction["context_codes"] + + # This needs the lock because the closer task might be removing entries concurrently + target_writer = self._get_or_create_writer(self.target_writers, shard_idx, self.target_codes_dir) + context_writer = self._get_or_create_writer(self.context_writers, shard_idx, self.context_codes_dir) + + # Submit the writing task + write_future = self.bg_worker_thread.submit( + self._write_worker, + target_cut_id, + shard_idx, + target_codes, + context_codes, + target_writer, + context_writer, + ) + self.futures_per_shard[shard_idx].append(write_future) + logging.debug(f"[Rank {self.rank}] Submitted write task for item {target_cut_id}, shard {shard_idx}") + + except Exception as e: + msg = f"[Rank {self.rank}] Error processing prediction item {prediction.get('target_cut_id', 'UNKNOWN')} from batch {batch_idx}: {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + def teardown(self, trainer: Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: + logging.info( + f"[Rank {self.rank}/{self.world_size}] Tearing down CodecPredictionWriter. 
Handling final shard and waiting for closers..." + ) + + # 1. Submit closer task for the very last processed shard (if any) + final_shard_processed = self.last_processed_shard_idx + if final_shard_processed != -1 and final_shard_processed in self.futures_per_shard: + logging.info( + f"[Rank {self.rank}] Submitting final closure task for last processed shard {final_shard_processed}." + ) + try: + closer_future = self.bg_worker_thread.submit(self._wait_and_close_worker, final_shard_processed) + self.closer_futures.append(closer_future) + except Exception as e: + msg = f"[Rank {self.rank}] Failed to submit final closer task for shard {final_shard_processed}: {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + # 2. Wait for all closer tasks to complete + num_closer_futures = len(self.closer_futures) + logging.info( + f"[Rank {self.rank}/{self.world_size}] Waiting for {num_closer_futures} background closer tasks to complete." + ) + processed_closer_futures = 0 + if self.closer_futures: + for future in tqdm( + self.closer_futures, + total=num_closer_futures, + desc=f"[Rank {self.rank}/{self.world_size}] Finalizing Shard Closures", + leave=False, + ): + try: + future.result() # Wait and check for exceptions from the closer worker + processed_closer_futures += 1 + except Exception as e: + msg = f"[Rank {self.rank}/{self.world_size}] Exception caught during closer future.result(): {e}" + logging.error(msg, exc_info=True) + raise ValueError(msg) + + logging.info( + f"[Rank {self.rank}/{self.world_size}] Completed {processed_closer_futures}/{num_closer_futures} closer tasks." + ) + else: + logging.info(f"[Rank {self.rank}/{self.world_size}] No closer tasks were submitted.") + + # 3. Shutdown the executor gracefully (all tasks should be done now) + if self.bg_worker_thread: + logging.info(f"[Rank {self.rank}/{self.world_size}] Shutting down background worker thread.") + self.bg_worker_thread.shutdown(wait=True) + self.bg_worker_thread = None + + # 4. Final sanity checks and cleanup + remaining_writers = len(self.target_writers) + len(self.context_writers) + if remaining_writers > 0: + msg = f"[Rank {self.rank}/{self.world_size}] {remaining_writers} writers remain after teardown! This should not happen. Keys: Target {list(self.target_writers.keys())}, Context {list(self.context_writers.keys())}" + logging.error(msg) + raise ValueError(msg) + + remaining_futures = sum(len(futs) for futs in self.futures_per_shard.values()) + if remaining_futures > 0: + msg = f"[Rank {self.rank}/{self.world_size}] {remaining_futures} write futures remain after teardown! This should not happen. Shards: {list(self.futures_per_shard.keys())}" + logging.error(msg) + raise ValueError(msg) + + self.target_writers.clear() + self.context_writers.clear() + self.futures_per_shard.clear() + self.closer_futures.clear() + + logging.info(f"[Rank {self.rank}/{self.world_size}] Teardown complete.") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuts-dir", type=str, required=True, help="Directory containing input cuts/cuts.*.jsonl.gz shards." + ) + parser.add_argument( + "--target-audio-dir", type=str, required=True, help="Directory containing target_audio/recording.*.tar shards." 
+ ) + parser.add_argument( + "--context-audio-dir", + type=str, + required=True, + help="Directory containing context_audio/recording.*.tar shards.", + ) + parser.add_argument("--output-dir", type=str, required=True, help="Base directory to save the output code shards.") + parser.add_argument( + "--codec-model-name", + type=str, + default="21fpsCausalDecoder", + help="Name for codec model (used in output path).", + ) + parser.add_argument( + "--codec-model-path", type=str, required=True, help="Path to the NeMo codec model (.nemo file)." + ) + parser.add_argument("--codec-frame-rate", type=float, default=21.5, help="Frame rate for codec model.") + parser.add_argument("--devices", type=int, default=-1, help="Number of GPUs per node (-1 for all).") + parser.add_argument("--num-nodes", type=int, default=1, help="Number of nodes for distributed processing.") + parser.add_argument("--batch-size", type=int, default=32, help="Batch size PER GPU for codec inference.") + parser.add_argument( + "--buffer-size", type=int, default=256, help="Number of items to buffer before writing to TAR files." + ) + parser.add_argument("--wandb-entity", type=str, default=None, help="Wandb entity.") + parser.add_argument("--wandb-project", type=str, default="lhotse_codes_extraction", help="Wandb project.") + parser.add_argument("--wandb-name", type=str, default=None, help="Wandb run name.") + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level.", + ) + args = parser.parse_args() + + log_level_val = getattr(logging, args.log_level.upper(), logging.INFO) + log_format = '%(asctime)s - PID:%(process)d - %(levelname)s - %(message)s' + logging.basicConfig(level=log_level_val, format=log_format) + + codec_extractor = CodecExtractor( + model_path=args.codec_model_path, + cuts_dir=args.cuts_dir, + target_audio_dir=args.target_audio_dir, + context_audio_dir=args.context_audio_dir, + batch_size=args.batch_size, + ) + + pred_writer = CodecPredictionWriter( + output_dir=args.output_dir, + codec_model_name=args.codec_model_name, + codec_frame_rate=args.codec_frame_rate, + ) + + wandb_logger = None + if args.wandb_entity and args.wandb_project: + wandb_logger = WandbLogger( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name or f"extract_codes_{args.codec_model_name}_{os.path.basename(args.cuts_dir)}", + log_model=False, + ) + logging.info(f"Wandb logging enabled to {args.wandb_entity}/{args.wandb_project}") + + strategy = DDPStrategy(find_unused_parameters=False) if torch.cuda.is_available() and args.devices != 1 else "auto" + trainer = Trainer( + devices=args.devices if torch.cuda.is_available() else 1, + num_nodes=args.num_nodes, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + strategy=strategy, + logger=wandb_logger, + callbacks=[pred_writer], + use_distributed_sampler=False, # we should disable replacing or wrapping Lhostse CutSampler with a `DistributedSamplerWrapper` since Lhotse's sampler already handles distributed sampling. 
+ ) + + logging.info(f"Starting prediction with {trainer.world_size} ranks.") + trainer.predict(codec_extractor, return_predictions=False) + logging.info("Prediction finished.") + + if trainer.is_global_zero and wandb_logger: + wandb.finish() + logging.info("Wandb run finished.") + + +if __name__ == "__main__": + import torch.multiprocessing + + try: + torch.multiprocessing.set_start_method('spawn') + except RuntimeError: + # This exception occurs if the start method has already been set. We can safely ignore it. + pass + main() diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py new file mode 100755 index 000000000000..001c8b307883 --- /dev/null +++ b/scripts/magpietts/infer_and_evaluate.py @@ -0,0 +1,808 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Inference and Evaluation script used for CI and NeMo model evaluation on custom datasets. +Please use this script as an example of how to do inference with MagpieTTS, but this script is otherwise unsupported +for general use cases. +""" +import argparse +import copy +import glob +import json +import os +import shutil +import time +from functools import partial +from pathlib import Path +from typing import List + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import scipy.stats as stats +import scripts.magpietts.evalset_config as evalset_config +import scripts.magpietts.evaluate_generated_audio as evaluate_generated_audio +import soundfile as sf +import torch +from omegaconf.omegaconf import OmegaConf, open_dict +from PIL import Image + +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer +from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset +from nemo.collections.tts.models import MagpieTTSModel + +# EVALUATION_DATASETS is the full list of datasets for evaluation of a new model. 
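+# These names correspond to the keys of scripts/magpietts/evalset_config.dataset_meta_info.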
+EVALUATION_DATASETS = ( + "riva_hard_digits,riva_hard_letters,riva_hard_money,riva_hard_short,vctk,libritts_seen,libritts_test_clean" +) + + +def compute_mean_and_confidence_interval(metrics_list, metric_keys, confidence=0.90): + metrics = {} + for key in metric_keys: + measurements = [m[key] for m in metrics_list] + mean = np.mean(measurements) + std_err = stats.sem(measurements) + + confidence_interval = std_err * stats.t.ppf((1 + confidence) / 2, len(measurements) - 1) + print(f"{key}: {mean} +/- {confidence_interval}") + metrics[key] = "{:.4f} +/- {:.4f}".format(mean, confidence_interval) + return metrics + + +def update_config(model_cfg, codecmodel_path, legacy_codebooks=False, legacy_text_conditioning=False): + '''helper function to rename older yamls from t5 to magpie''' + model_cfg.codecmodel_path = codecmodel_path + if hasattr(model_cfg, 'text_tokenizer'): + # Backward compatibility for models trained with absolute paths in text_tokenizer + model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722" + model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0 + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.legacy_text_conditioning = legacy_text_conditioning + if "t5_encoder" in model_cfg: + model_cfg.encoder = model_cfg.t5_encoder + del model_cfg.t5_encoder + if "t5_decoder" in model_cfg: + model_cfg.decoder = model_cfg.t5_decoder + del model_cfg.t5_decoder + if hasattr(model_cfg, 'decoder') and hasattr(model_cfg.decoder, 'prior_eps'): + # Added to prevent crash after removing arg from transformer_2501.py in https://github.com/blisc/NeMo/pull/56 + del model_cfg.decoder.prior_eps + if hasattr(model_cfg, 'use_local_transformer') and model_cfg.use_local_transformer: + # For older checkpoints trained with a different parameter name + model_cfg.local_transformer_type = "autoregressive" + del model_cfg.use_local_transformer + if hasattr(model_cfg, 'downsample_factor'): + # Backward compatibility for models trained with the config option`downsample_factor` which was later renamed to `frame_stacking_factor` + model_cfg.frame_stacking_factor = model_cfg.downsample_factor + del model_cfg.downsample_factor + if legacy_codebooks: + # Added to address backward compatibility arising from + # https://github.com/blisc/NeMo/pull/64 + print( + "WARNING: Using legacy codebook indices for backward compatibility. Should only be used with old checkpoints." 
+ ) + num_audio_tokens_per_codebook = model_cfg.num_audio_tokens_per_codebook + model_cfg.forced_num_all_tokens_per_codebook = num_audio_tokens_per_codebook + model_cfg.forced_audio_eos_id = num_audio_tokens_per_codebook - 1 + model_cfg.forced_audio_bos_id = num_audio_tokens_per_codebook - 2 + if model_cfg.model_type == 'decoder_context_tts': + model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 3 + model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 4 + model_cfg.forced_mask_token_id = num_audio_tokens_per_codebook - 5 + else: + model_cfg.forced_context_audio_eos_id = num_audio_tokens_per_codebook - 1 + model_cfg.forced_context_audio_bos_id = num_audio_tokens_per_codebook - 2 + if hasattr(model_cfg, 'sample_rate'): + # This was removed from the config and is now in the model class + sample_rate = model_cfg.sample_rate + del model_cfg.sample_rate + else: + sample_rate = None + return model_cfg, sample_rate + + +def update_ckpt(state_dict): + new_state_dict = {} + for key in state_dict.keys(): + if 't5_encoder' in key: + new_key = key.replace('t5_encoder', 'encoder') + new_state_dict[new_key] = state_dict[key] + elif 't5_decoder' in key: + new_key = key.replace('t5_decoder', 'decoder') + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[key] = state_dict[key] + return new_state_dict + + +def delete_old_generated_files(output_dir): + # Delete any leftover generated files from previous runs as these can confuse the evaluation + print(f"Deleting old generated files in: {output_dir} ...") + for f in glob.glob(f"{output_dir}/predicted_codes*.pt"): + os.remove(f) + for f in glob.glob(f"{output_dir}/predicted_audio*.wav"): + os.remove(f) + for f in glob.glob(f"{output_dir}/cross_attn_map_*.png"): + os.remove(f) + + +def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: str): + # Create dataframe from list of dicts + df = pd.DataFrame(metrics) + + # Plot the violin plots for all DataFrames side by side + num_columns = len(metric_keys) + width = num_columns * 5 + fig, axs = plt.subplots(1, num_columns, figsize=(width, 4)) + + for i, column in enumerate(metric_keys): + assert column in df + # Create empty lists to store the parts objects for each DataFrame + # Plot the violin plots for each DataFrame + axs[i].violinplot(df[column], showmedians=True, positions=[i], widths=0.5) + + axs[i].set_title(column) + axs[i].set_xticks([i]) + axs[i].set_xticklabels([column]) + axs[i].grid(True, linestyle="dotted") + + # Calculate and display the mean value for each DataFrame + mean = df[column].mean() + sem = df[column].sem() + axs[i].plot(i, mean, "o", color="red", markersize=4, label="Mean (95%CI)") + + label_numeric = f"{mean:.2f}±{1.96 * sem:.2f}" + axs[i].text(i + 0.06, mean, label_numeric, ha="center", va="top") + + # Create a single legend for all subplots + handles, labels = axs[0].get_legend_handles_labels() + fig.legend(handles, labels, loc="upper left") + + plt.tight_layout() + plt.savefig(output_png, format="png", bbox_inches="tight") + + +def create_combined_violin_plots(dataset_metrics: dict, metric_keys: List[str], output_png: str): + """ + Create box plots comparing multiple datasets for each metric in a single figure. 
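+    Note: despite the function name, box plots (not violin plots) are drawn for each dataset.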
+ + Args: + dataset_metrics: Dictionary where keys are dataset names and values are lists of metric dictionaries + metric_keys: List of metric names to plot + output_png: Output file path for the combined plot + """ + # Prepare data for plotting + datasets = list(dataset_metrics.keys()) + num_datasets = len(datasets) + num_metrics = len(metric_keys) + + # Create figure with subplots for each metric + fig, axs = plt.subplots(1, num_metrics, figsize=(num_metrics * 6, 6)) + + # Handle case where there's only one metric (axs won't be an array) + if num_metrics == 1: + axs = [axs] + + # Define colors for different datasets + colors = plt.cm.Set3(np.linspace(0, 1, num_datasets)) + + for metric_idx, metric in enumerate(metric_keys): + ax = axs[metric_idx] + + # Collect data for all datasets for this metric + all_data = [] + positions = [] + dataset_labels = [] + + for dataset_idx, dataset in enumerate(datasets): + df = pd.DataFrame(dataset_metrics[dataset]) + if metric in df.columns: + data = df[metric].dropna() + all_data.append(data) + positions.append(dataset_idx + 1) + dataset_labels.append(dataset) + + # Create box plots + if all_data: + bp = ax.boxplot( + all_data, + positions=positions, + widths=0.6, + patch_artist=True, + showmeans=True, + meanline=False, + meanprops={'marker': 'o', 'markerfacecolor': 'red', 'markeredgecolor': 'red', 'markersize': 6}, + ) + + # Color the box plots + for i, patch in enumerate(bp['boxes']): + patch.set_facecolor(colors[i]) + patch.set_alpha(0.7) + + # Add mean labels for each dataset + for i, (data, pos) in enumerate(zip(all_data, positions)): + mean = data.mean() + sem = data.sem() + + label_numeric = f"{mean:.3f}±{1.96 * sem:.3f}" + ax.text(pos + 0.1, mean, label_numeric, ha="left", va="center", fontsize=8) + + # Set labels and title + ax.set_title(f"{metric.upper()}", fontsize=12, fontweight='bold') + ax.set_xticks(positions) + ax.set_xticklabels(dataset_labels, rotation=45, ha='right') + ax.grid(True, linestyle="dotted", alpha=0.7) + ax.set_xlabel("Dataset") + ax.set_ylabel(metric) + + # Set y-axis limit for CER metrics + if 'cer' in metric.lower(): + ax.set_ylim(0, 0.3) + + # Add overall title + fig.suptitle("Performance Comparison Across Datasets", fontsize=14, fontweight='bold') + + # Adjust layout and save + plt.tight_layout() + plt.savefig(output_png, format="png", bbox_inches="tight", dpi=300) + plt.close() + print(f"Combined violin plot saved to: {output_png}") + + +def run_inference( + hparams_file, + checkpoint_file, + nemo_file, + datasets, + out_dir, + temperature, + topk, + codecmodel_path, + use_cfg, + cfg_scale, + batch_size, + sv_model, + asr_model_name, + num_repeats=1, + apply_attention_prior=False, + attention_prior_epsilon=1e-3, + attention_prior_lookahead_window=10, + estimate_alignment_from_layers=None, + apply_prior_to_layers=None, + start_prior_after_n_audio_steps=10, + confidence_level=0.95, + use_local_transformer=False, + maskgit_n_steps=3, + maskgit_noise_scale=0.0, + maskgit_fixed_schedule=None, + maskgit_sampling_type=None, + legacy_codebooks=False, + legacy_text_conditioning=False, + clean_up_disk=False, + hparams_file_from_wandb=False, + log_exp_name=False, + compute_fcd=False, + violin_plot_metrics=None, + eos_detection_method=None, + ignore_finished_sentence_tracking=False, + with_utmosv2=True, +): + # Avoid lists as default values and apply default value in function + if violin_plot_metrics is None: + violin_plot_metrics = ['cer', 'pred_context_ssim', 'utmosv2'] + # Load model + if hparams_file is not None and 
checkpoint_file is not None: + model_cfg = OmegaConf.load(hparams_file) + if "cfg" in model_cfg: + model_cfg = model_cfg.cfg + + if hparams_file_from_wandb: + model_cfg = model_cfg.value + + with open_dict(model_cfg): + model_cfg, cfg_sample_rate = update_config( + model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning + ) + + model = MagpieTTSModel(cfg=model_cfg) + model.use_kv_cache_for_inference = True + + # Load weights from checkpoint file + print("Loading weights from checkpoint") + ckpt = torch.load(checkpoint_file, weights_only=False) + state_dict = update_ckpt(ckpt['state_dict']) + model.load_state_dict(state_dict) + checkpoint_name = checkpoint_file.split("/")[-1].split(".ckpt")[0] + elif nemo_file is not None: + model_cfg = MagpieTTSModel.restore_from(nemo_file, return_config=True) + with open_dict(model_cfg): + model_cfg, cfg_sample_rate = update_config( + model_cfg, codecmodel_path, legacy_codebooks, legacy_text_conditioning + ) + model = MagpieTTSModel.restore_from(nemo_file, override_config_path=model_cfg) + model.use_kv_cache_for_inference = True + checkpoint_name = nemo_file.split("/")[-1].split(".nemo")[0] + else: + raise ValueError("Need either a checkpoint and hparams file, or a nemo file.") + + if cfg_sample_rate is not None and cfg_sample_rate != model.sample_rate: + raise ValueError("Sample rate in config and model do not match") + + print("Loaded weights.") + model.cuda() + model.eval() + + if log_exp_name: + # the experiment name is the name of the directory two above the checkpoint path, + # since training produces directories of the form `exp_name/checkpoints/checkpoint_name.ckpt`. + exp_name = f"{os.path.basename(os.path.dirname(os.path.dirname(checkpoint_file)))}__" + else: + exp_name = "" + + # Build checkpoint name + checkpoint_name = ( + f"{exp_name}{checkpoint_name}_Temp{temperature}_Topk{topk}_Cfg_{use_cfg}_{cfg_scale}_" + f"Prior_{apply_attention_prior}_" + ) + if apply_attention_prior: + # Only add prior config details if prior is enabled (to avoid super long checkpoint names) + checkpoint_name += ( + f"{attention_prior_epsilon}_{attention_prior_lookahead_window}_{start_prior_after_n_audio_steps}_" + f"{''.join([str(l) for l in estimate_alignment_from_layers]) if estimate_alignment_from_layers is not None else 'None'}_" + f"{''.join([str(l) for l in apply_prior_to_layers]) if apply_prior_to_layers is not None else 'None'}_" + ) + checkpoint_name += ( + f"LT_{use_local_transformer}_" + f"MaskGit_{maskgit_n_steps}_{maskgit_sampling_type}_{''.join([str(l) for l in maskgit_fixed_schedule]) if maskgit_fixed_schedule is not None else 'None'}_" + f"SV_{sv_model}" + f"EOS_{eos_detection_method}" + f"IgnoreFST_{ignore_finished_sentence_tracking}" + ) + + dataset_meta_info = evalset_config.dataset_meta_info + ssim_per_dataset = [] + cer_per_dataset = [] + all_datasets_filewise_metrics = {} # Store filewise metrics for all datasets for combined violin plot + if (not with_utmosv2) and ('utmosv2' in violin_plot_metrics): + violin_plot_metrics.remove('utmosv2') + for dataset in datasets: + print(f"Evaluating dataset {dataset}") + metrics_n_repeated = [] + manifest_records = read_manifest(dataset_meta_info[dataset]['manifest_path']) + language = dataset_meta_info[dataset].get('whisper_language', 'en') + dataset_meta_for_dl = copy.deepcopy(dataset_meta_info[dataset]) + for key in ["whisper_language", "load_cached_codes_if_available"]: + if key in dataset_meta_for_dl: + del dataset_meta_for_dl[key] + + dataset_meta = {dataset: dataset_meta_for_dl} + + 
eval_dir = os.path.join(out_dir, f"{checkpoint_name}_{dataset}") + audio_dir = os.path.join(eval_dir, "audio") + all_experiment_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") + os.makedirs(eval_dir, exist_ok=True) + + if not os.path.exists(all_experiment_csv): + with open(all_experiment_csv, "w") as f: + header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,utmosv2,total_gen_audio_seconds" + if compute_fcd: + header += ",frechet_codec_distance" + header += "\n" + f.write(header) + + context_duration_min = model.cfg.get('context_duration_min', 5.0) + context_duration_max = model.cfg.get('context_duration_max', 5.0) + if context_duration_min < 5.0 and context_duration_max > 5.0: + context_duration_min = 5.0 + context_duration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though. + + dataset_filewise_metrics_all_repeats = [] # Store metrics for all repeats of this dataset + for repeat_idx in range(num_repeats): + pred_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") + os.makedirs(pred_audio_dir, exist_ok=True) + delete_old_generated_files(pred_audio_dir) + + test_dataset = MagpieTTSDataset( + dataset_meta=dataset_meta, + sample_rate=model.sample_rate, + min_duration=0.5, + max_duration=20, + codec_model_samples_per_frame=model.codec_model_samples_per_frame, + bos_id=model.bos_id, + eos_id=model.eos_id, + context_audio_bos_id=model.context_audio_bos_id, + context_audio_eos_id=model.context_audio_eos_id, + audio_bos_id=model.audio_bos_id, + audio_eos_id=model.audio_eos_id, + num_audio_codebooks=model.num_audio_codebooks, + prior_scaling_factor=None, + load_cached_codes_if_available=False, + dataset_type='test', + tokenizer_config=None, + load_16khz_audio=model.model_type == 'single_encoder_sv_tts', + use_text_conditioning_tokenizer=model.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=model.text_conditioning_tokenizer_name, + pad_context_text_to_max_duration=model.pad_context_text_to_max_duration, + context_duration_min=context_duration_min, + context_duration_max=context_duration_max, + ) + assert len(test_dataset) == len( + manifest_records + ), f"Dataset length and manifest length should be the same. 
Dataset length: {len(test_dataset)}, Manifest length: {len(manifest_records)}" + + test_dataset.text_tokenizer = model.tokenizer + # Set phoneme prob = 1 for g2p + g2p = None + if isinstance(model.tokenizer, AggregatedTTSTokenizer): + g2p = model.tokenizer.tokenizers["english_phoneme"].g2p + elif isinstance(model.tokenizer, IPATokenizer): + g2p = model.tokenizer.g2p + if g2p is not None: + g2p.phoneme_probability = 1.0 + + test_data_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=batch_size, + collate_fn=test_dataset.collate_fn, + num_workers=2, + shuffle=False, + ) + + item_idx = 0 + all_rtf_metrics = [] + codec_file_paths = [] + for bidx, batch in enumerate(test_data_loader): + print(f"Processing batch {bidx} out of {len(test_data_loader)} of dataset {dataset}") + batch_cuda = {} + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch_cuda[key] = batch[key].cuda() + else: + batch_cuda[key] = batch[key] + + st = time.time() + ( + predicted_audio, + predicted_audio_lens, + predicted_codes, + predicted_codes_lens, + rtf_metrics, + cross_attention_maps, + _, + ) = model.infer_batch( + batch_cuda, + max_decoder_steps=440, + temperature=temperature, + topk=topk, + use_cfg=use_cfg, + cfg_scale=cfg_scale, + return_cross_attn_probs=True, + apply_attention_prior=apply_attention_prior, + prior_epsilon=attention_prior_epsilon, + lookahead_window_size=attention_prior_lookahead_window, + estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=start_prior_after_n_audio_steps, + use_local_transformer_for_inference=use_local_transformer, + maskgit_n_steps=maskgit_n_steps, + maskgit_noise_scale=maskgit_noise_scale, + maskgit_fixed_schedule=maskgit_fixed_schedule, + maskgit_sampling_type=maskgit_sampling_type, + ignore_finished_sentence_tracking=ignore_finished_sentence_tracking, + eos_detection_method=eos_detection_method, + ) + + all_rtf_metrics.append(rtf_metrics) + et = time.time() + print(f"Time taken for inference: {et-st}", predicted_audio.size()) + for idx in range(predicted_audio.size(0)): + cross_attn_map_image = Image.fromarray(cross_attention_maps[idx]) + cross_attn_map_image.save(os.path.join(pred_audio_dir, f"cross_attn_map_{item_idx}.png")) + + predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() + predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] + audio_path = os.path.join(pred_audio_dir, f"predicted_audio_{item_idx}.wav") + sf.write(audio_path, predicted_audio_np, model.sample_rate) + codes_path = os.path.join(pred_audio_dir, f"predicted_codes_{item_idx}.pt") + predicted_codes_current = predicted_codes[idx, :, : predicted_codes_lens[idx]] # C, T' + torch.save(predicted_codes_current, codes_path) + codec_file_paths.append(codes_path) + context_audio_path = manifest_records[item_idx].get('context_audio_filepath', None) + target_audio_path = manifest_records[item_idx].get('audio_filepath', None) + if context_audio_path is not None: + context_audio_path = os.path.join(dataset_meta_info[dataset]['audio_dir'], context_audio_path) + if target_audio_path is not None: + target_audio_path = os.path.join(dataset_meta_info[dataset]['audio_dir'], target_audio_path) + if os.path.exists(context_audio_path): + shutil.copy(context_audio_path, os.path.join(audio_dir, f"context_audio_{item_idx}.wav")) + if os.path.exists(target_audio_path): + shutil.copy(target_audio_path, os.path.join(audio_dir, f"target_audio_{item_idx}.wav")) + item_idx += 1 + + 
mean_rtf_metrics = {} + for key in all_rtf_metrics[0]: + mean_rtf_metrics[key] = float(np.mean([m[key] for m in all_rtf_metrics])) + + metrics, filewise_metrics = evaluate_generated_audio.evaluate( + dataset_meta[dataset]['manifest_path'], + dataset_meta[dataset]['audio_dir'], + pred_audio_dir, + language=language, + sv_model_type=sv_model, + asr_model_name=asr_model_name, + codecmodel_path=codecmodel_path if compute_fcd else None, + with_utmosv2=with_utmosv2, + ) + metrics_n_repeated.append(metrics) + dataset_filewise_metrics_all_repeats.extend( + filewise_metrics + ) # Collect all filewise metrics for combined plot + + with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: + json.dump(metrics, f, indent=4) + + with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: + # Indent for better readability + json.dump(filewise_metrics, f, indent=4) + + with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: + json.dump(mean_rtf_metrics, f, indent=4) + + with open(all_experiment_csv, "a") as f: + data = f"{checkpoint_name},{dataset},{metrics['cer_filewise_avg']},{metrics['wer_filewise_avg']},{metrics['cer_cumulative']},{metrics['wer_cumulative']},{metrics['ssim_pred_gt_avg']},{metrics['ssim_pred_context_avg']},{metrics['ssim_gt_context_avg']},{metrics['ssim_pred_gt_avg_alternate']},{metrics['ssim_pred_context_avg_alternate']},{metrics['ssim_gt_context_avg_alternate']},{metrics['cer_gt_audio_cumulative']},{metrics['wer_gt_audio_cumulative']},{metrics['utmosv2_avg']},{metrics['total_gen_audio_seconds']}" + if compute_fcd: + data += f",{metrics['frechet_codec_distance']}" + data += "\n" + f.write(data) + print(f"Wrote metrics for {checkpoint_name} and {dataset} to {all_experiment_csv}") + + output_png_file = Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png" + create_violin_plots(filewise_metrics, violin_plot_metrics, output_png_file) + + # Clean up temporary codec files + for codes_file in codec_file_paths: + os.remove(codes_file) + + # Store filewise metrics for this dataset for combined plotting + all_datasets_filewise_metrics[dataset] = dataset_filewise_metrics_all_repeats + + metric_keys = [ + 'cer_filewise_avg', + 'wer_filewise_avg', + 'cer_cumulative', + 'wer_cumulative', + 'ssim_pred_gt_avg', + 'ssim_pred_context_avg', + 'ssim_gt_context_avg', + 'ssim_pred_gt_avg_alternate', + 'ssim_pred_context_avg_alternate', + 'ssim_gt_context_avg_alternate', + 'cer_gt_audio_cumulative', + 'wer_gt_audio_cumulative', + 'utmosv2_avg', + 'total_gen_audio_seconds', + ] + if compute_fcd: + metric_keys.append('frechet_codec_distance') + metrics_mean_ci = compute_mean_and_confidence_interval( + metrics_n_repeated, metric_keys, confidence=confidence_level + ) + all_experiment_csv_with_ci = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") + if not os.path.exists(all_experiment_csv_with_ci): + with open(all_experiment_csv_with_ci, "w") as f: + header = "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,utmosv2_avg,total_gen_audio_seconds" + if compute_fcd: + header += ",frechet_codec_distance" + header += "\n" + f.write(header) + with open(all_experiment_csv_with_ci, "a") as f: + data = 
f"{checkpoint_name},{dataset},{metrics_mean_ci['cer_filewise_avg']},{metrics_mean_ci['wer_filewise_avg']},{metrics_mean_ci['cer_cumulative']},{metrics_mean_ci['wer_cumulative']},{metrics_mean_ci['ssim_pred_gt_avg']},{metrics_mean_ci['ssim_pred_context_avg']},{metrics_mean_ci['ssim_gt_context_avg']},{metrics_mean_ci['ssim_pred_gt_avg_alternate']},{metrics_mean_ci['ssim_pred_context_avg_alternate']},{metrics_mean_ci['ssim_gt_context_avg_alternate']},{metrics_mean_ci['cer_gt_audio_cumulative']},{metrics_mean_ci['wer_gt_audio_cumulative']},{metrics_mean_ci['utmosv2_avg']},{metrics_mean_ci['total_gen_audio_seconds']}" + if compute_fcd: + data += f",{metrics_mean_ci['frechet_codec_distance']}" + data += "\n" + f.write(data) + print(f"Wrote metrics with CI for {checkpoint_name} and {dataset} to {all_experiment_csv_with_ci}") + + measurements = [m['ssim_pred_context_avg'] for m in metrics_n_repeated] + ssim_current = np.mean(measurements) + ssim_per_dataset.append(ssim_current) + measurements = [m['cer_cumulative'] for m in metrics_n_repeated] + cer_current = np.mean(measurements) + cer_per_dataset.append(cer_current) + + # Create combined violin plot for all datasets + if len(all_datasets_filewise_metrics) > 1: # Only create combined plot if we have multiple datasets + combined_output_png = os.path.join(out_dir, f"{checkpoint_name}_combined_violin_plot.png") + create_combined_violin_plots(all_datasets_filewise_metrics, violin_plot_metrics, combined_output_png) + + # Average across datasets + ssim = np.mean(ssim_per_dataset) + cer = np.mean(cer_per_dataset) + if clean_up_disk: + shutil.rmtree(out_dir) + return cer, ssim + + +def main(): + parser = argparse.ArgumentParser(description='Experiment Evaluation') + parser.add_argument('--hparams_files', type=str, default=None) + parser.add_argument('--hparams_file_from_wandb', action='store_true') + parser.add_argument('--checkpoint_files', type=str, default=None) + parser.add_argument('--nemo_files', type=str, default=None) + parser.add_argument('--codecmodel_path', type=str, default=None, help="Path to codec model") + parser.add_argument('--datasets', type=str, default=None) + # Parameters for running inference experiments locally + parser.add_argument('--out_dir', type=str, default="/datap/misc/Evals/LocalTransformerAblations2") + parser.add_argument('--temperature', type=float, default=0.6) + parser.add_argument('--use_cfg', action='store_true') + parser.add_argument( + '--use_local_transformer', + action='store_true', + help="Enables use of local transformer for inference; applies to both Autoregressive and MaskGit sampling.", + ) + parser.add_argument('--maskgit_n_steps', type=int, default=3) + parser.add_argument('--maskgit_noise_scale', type=float, default=0.0) + parser.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) + parser.add_argument( + '--maskgit_sampling_type', default=None, choices=["default", "causal", "purity_causal", "purity_default"] + ) + parser.add_argument('--cfg_scale', type=float, default=2.5) + parser.add_argument('--apply_attention_prior', action='store_true') + parser.add_argument('--attention_prior_epsilon', type=float, default=0.1) + parser.add_argument('--attention_prior_lookahead_window', type=int, default=5) + parser.add_argument('--estimate_alignment_from_layers', type=str, default=None) + parser.add_argument('--apply_prior_to_layers', type=str, default=None) + parser.add_argument('--start_prior_after_n_audio_steps', type=int, default=0) + parser.add_argument('--topk', type=int, default=80) + 
parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument( + '--eos_detection_method', + type=str, + default="argmax_or_multinomial_any", + choices=[ + "argmax_any", + "argmax_or_multinomial_any", + "argmax_all", + "argmax_or_multinomial_all", + "argmax_zero_cb", + "argmax_or_multinomial_zero_cb", + ], + ) + # Parameters for evaluation + parser.add_argument('--sv_model', type=str, default="titanet") # titanet, wavlm + parser.add_argument( + '--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b" + ) # stt_en_conformer_transducer_large, nvidia/parakeet-ctc-0.6b + parser.add_argument('--num_repeats', type=int, default=1) + parser.add_argument('--confidence_level', type=float, default=0.95) + parser.add_argument('--legacy_codebooks', action='store_true') + parser.add_argument('--legacy_text_conditioning', action='store_true') + parser.add_argument('--ignore_finished_sentence_tracking', action='store_true') + parser.add_argument('--clean_up_disk', action='store_true') + parser.add_argument('--cer_target', type=float, default=None) + parser.add_argument('--ssim_target', type=float, default=None) + parser.add_argument('--disable_utmosv2', action='store_true', help="Disable UTMOSv2 computation") + parser.add_argument( + '--log_exp_name', + action='store_true', + help="Include the experiment name (derived from the checkpoint path) in the output folder name.", + ) + parser.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation") + parser.add_argument( + '--violin_plot_metrics', + type=str, + nargs='*', + default=['cer', 'pred_context_ssim', 'utmosv2'], + help="Which metrics to add the violin plot.", + ) + args = parser.parse_args() + + if args.datasets is None: + args.datasets = EVALUATION_DATASETS + + # FCD computation is enabled by default, disabled only when --disable_fcd is specified + compute_fcd = not args.disable_fcd + + estimate_alignment_from_layers = None + if args.estimate_alignment_from_layers is not None: + estimate_alignment_from_layers = [int(l.strip()) for l in args.estimate_alignment_from_layers.split(",")] + apply_prior_to_layers = None + if args.apply_prior_to_layers is not None: + apply_prior_to_layers = [int(l.strip()) for l in args.apply_prior_to_layers.split(",")] + + run_inference_w_args = partial( + run_inference, + datasets=args.datasets.split(","), + out_dir=args.out_dir, + temperature=args.temperature, + topk=args.topk, + codecmodel_path=args.codecmodel_path, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + batch_size=args.batch_size, + sv_model=args.sv_model, + asr_model_name=args.asr_model_name, + num_repeats=args.num_repeats, + apply_attention_prior=args.apply_attention_prior, + attention_prior_epsilon=args.attention_prior_epsilon, + attention_prior_lookahead_window=args.attention_prior_lookahead_window, + estimate_alignment_from_layers=estimate_alignment_from_layers, + apply_prior_to_layers=apply_prior_to_layers, + start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, + confidence_level=args.confidence_level, + use_local_transformer=args.use_local_transformer, + maskgit_n_steps=args.maskgit_n_steps, + maskgit_noise_scale=args.maskgit_noise_scale, + maskgit_fixed_schedule=args.maskgit_fixed_schedule, + maskgit_sampling_type=args.maskgit_sampling_type, + legacy_codebooks=args.legacy_codebooks, + legacy_text_conditioning=args.legacy_text_conditioning, + clean_up_disk=args.clean_up_disk, + hparams_file_from_wandb=args.hparams_file_from_wandb, + log_exp_name=args.log_exp_name, 
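+        # compute_fcd (from --disable_fcd) and with_utmosv2 (from --disable_utmosv2) are
+        # negations of the corresponding disable flags rather than direct argparse values.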
+ compute_fcd=compute_fcd, + violin_plot_metrics=args.violin_plot_metrics, + eos_detection_method=args.eos_detection_method, + ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking, + with_utmosv2=not args.disable_utmosv2, + ) + + cer, ssim = None, None + # Mode 1: Run inference from provided hparams and checkpoint files + if ( + (args.hparams_files is not None) + and (args.checkpoint_files is not None) + and (args.hparams_files != "null") + and (args.checkpoint_files != "null") + ): + hparam_files = args.hparams_files.split(",") + checkpoint_files = args.checkpoint_files.split(",") + print("Running inference for hparams files: ", hparam_files) + print("Running inference for checkpoint files: ", checkpoint_files) + assert len(hparam_files) == len( + checkpoint_files + ), "Number of hparams files and checkpoint files should be the same." + for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): + cer, ssim = run_inference_w_args( + hparams_file=hparams_file, + checkpoint_file=checkpoint_file, + nemo_file=None, + ) + # Mode 2: Run inference from a .nemo file + elif args.nemo_files: + print(f"Running inference for nemo file: {args.nemo_files}") + for nemo_file in args.nemo_files.split(","): + cer, ssim = run_inference_w_args( + hparams_file=None, + checkpoint_file=None, + nemo_file=nemo_file, + ) + else: + parser.error( + "You must provide a model to run. Please specify either:\n" + "1. --hparams_files and --checkpoint_files\n" + "2. --nemo_file\n" + ) + if cer is not None and args.cer_target is not None and cer > float(args.cer_target): + raise ValueError() + if ssim is not None and args.ssim_target is not None and ssim < float(args.ssim_target): + raise ValueError() + + +if __name__ == '__main__': + main() diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index 4df335d59843..ef608b3d4ec6 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -1968,35 +1968,6 @@ def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: assert -5.0 < cut.tracks[1].snr < 5.0 -def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): - config = OmegaConf.create( - { - "cuts_path": str(cutset_path), - "noise_path": str(nemo_manifest_path), - "noise_mix_prob": 1.0, - "noise_snr": [-5.0, 5.0], - "batch_size": 2, - "seed": 0, - "shard_seed": 0, - } - ) - dl = get_lhotse_dataloader_from_config( - config=config, - global_rank=0, - world_size=1, - dataset=Identity(), - ) - batch = next(iter(dl)) - assert isinstance(batch, CutSet) - assert len(batch) == 2 - cut = batch[0] - assert isinstance(cut, MixedCut) - assert -5.0 < cut.tracks[1].snr < 5.0 - cut = batch[1] - assert isinstance(cut, MixedCut) - assert -5.0 < cut.tracks[1].snr < 5.0 - - def test_dataloader_with_noise_lhotse_jsonl(cutset_path: Path): config = OmegaConf.create( { diff --git a/tests/collections/common/test_lhotse_tts_filters.py b/tests/collections/common/test_lhotse_tts_filters.py new file mode 100644 index 000000000000..877b0886ab97 --- /dev/null +++ b/tests/collections/common/test_lhotse_tts_filters.py @@ -0,0 +1,143 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from lhotse import SupervisionSegment +from lhotse.array import Array, TemporalArray +from lhotse.audio import AudioSource, Recording +from lhotse.cut import MonoCut + +from nemo.collections.common.data.lhotse.sampling import ( + CERFilter, + ContextSpeakerSimilarityFilter, + ValidationStatusFilter, +) + + +@pytest.fixture +def cut_example(): + cut = MonoCut( + id='cut-rec-Zdud2gXLTXY-238.16-6.88_repeat0', + start=238.16, + duration=6.88, + channel=0, + supervisions=[ + SupervisionSegment( + id='sup-rec-Zdud2gXLTXY', + recording_id='rec-Zdud2gXLTXY', + start=238.16, + duration=6.88, + channel=0, + text='and in like manner, as do other parts in which there appears to exist an adaptation to an end.', + language='en', + speaker='| Language:en Dataset:nvyt2505 Speaker:Zdud2gXLTXY_SPEAKER_02 |', + gender=None, + custom={ + 'cer': 0.03, + 'bandwidth': 10875, + 'stoi_squim': 0.921, + 'sisdr_squim': 15.17, + 'pesq_squim': 1.845, + 'dataset_id': '5a6446c5-6114-4380-b875-9de17fda2b8d', + 'dataset_version': '2024_11_07_131440', + 'dataset_name': 'yt_mixed', + 'context_speaker_similarity': 0.9172529578208923, + 'context_audio_offset': 7001.95659375, + 'context_audio_duration': 14.64, + 'context_audio_text': 'Uat gives an excellent illustration of the effects of a course of selection, which may be considered as unconscious, insofar that the breeders could never have expected, or even wished, to produce the result which ensued,', + 'context_recording_id': 'rec-Zdud2gXLTXY', + }, + alignment=None, + ) + ], + features=None, + recording=Recording( + id='rec-Zdud2gXLTXY', + sources=[AudioSource(type='file', channels=[0], source='/audio/Zdud2gXLTXY.wav')], + sampling_rate=22050, + num_samples=952064173, + duration=43177.51351473923, + channel_ids=[0], + transforms=None, + ), + custom={ + 'validation_status': 'pass', + 'target_audio': Recording( + id='cut-rec-Zdud2gXLTXY-238.16-6.88', + sources=[AudioSource(type='memory', channels=[0], source='')], + sampling_rate=22050, + num_samples=151704, + duration=6.88, + channel_ids=[0], + transforms=None, + ), + 'context_audio': Recording( + id='context_cut-rec-Zdud2gXLTXY-7001.96-14.64', + sources=[AudioSource(type='memory', channels=[0], source='')], + sampling_rate=22050, + num_samples=322812, + duration=14.64, + channel_ids=[0], + transforms=None, + ), + 'target_codes': TemporalArray( + array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 149]), + temporal_dim=-1, + frame_shift=0.046511627906976744, + start=0, + ), + 'context_codes': TemporalArray( + array=Array(storage_type='memory_npy', storage_path='', storage_key='', shape=[8, 316]), + temporal_dim=-1, + frame_shift=0.046511627906976744, + start=0, + ), + 'shard_origin': '/cuts/cuts.000001.jsonl.gz', + 'shar_epoch': 0, + 'tokenizer_names': ['english_phoneme'], + }, + ) + return cut + + +def test_cut_cer_filter(cut_example): + f = CERFilter(0.4) + assert f(cut_example) == True + + f = CERFilter(0.01) + assert f(cut_example) == False + + f = CERFilter(float("inf")) + assert f(cut_example) == True + + +def 
test_cut_context_speaker_similarity_filter(cut_example): + f = ContextSpeakerSimilarityFilter(0.6) + assert f(cut_example) == True + + f = ContextSpeakerSimilarityFilter(0.95) + assert f(cut_example) == False + + f = ContextSpeakerSimilarityFilter(-1) + assert f(cut_example) == True + + +def test_cut_validation_status_filter(cut_example): + f = ValidationStatusFilter("pass") + assert f(cut_example) == True + + f = ValidationStatusFilter("wrong_text") + assert f(cut_example) == False + + f = ValidationStatusFilter("any_other_status") + assert f(cut_example) == False diff --git a/tests/collections/tts/modules/test_fcd_metric.py b/tests/collections/tts/modules/test_fcd_metric.py new file mode 100644 index 000000000000..dfbcbe1f90e0 --- /dev/null +++ b/tests/collections/tts/modules/test_fcd_metric.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.fcd_metric import FrechetCodecDistance + + +class TestFrechetCodecDistance: + @pytest.fixture + def device(self): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @pytest.fixture + def codec(self, device, scope="session"): + return AudioCodecModel.from_pretrained("nvidia/low-frame-rate-speech-codec-22khz").to(device) + + @pytest.fixture + def metric(self, codec, device): + codec_feature_dim = codec.vector_quantizer.codebook_dim + return FrechetCodecDistance(codec=codec, feature_dim=codec_feature_dim).to(device) + + @pytest.mark.unit + def test_same_distribution(self, metric, device, codec): + """Test that FCD is close to zero when comparing identical distributions.""" + B, C, T = 3, codec.num_codebooks, 20 + codes = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + # Update with same codes for both real and fake + metric.update(codes, codes_len, is_real=True) + metric.update(codes, codes_len, is_real=False) + + eps = 0.01 + fcd = metric.compute() + assert fcd < eps and fcd >= 0, f"FCD value is {fcd} but should be close to 0" + metric.reset() + + @pytest.mark.unit + def test_different_distribution(self, metric, device, codec): + """Test that FCD is positive when comparing different distributions.""" + B, C, T = 3, codec.num_codebooks, 20 + + # Generate two different sets of codes + codes1 = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes2 = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + metric.update(codes1, codes_len, is_real=True) + metric.update(codes2, codes_len, is_real=False) + + fcd = metric.compute() + assert fcd > 0, f"FCD value is {fcd} but should be positive for different distributions" + metric.reset() + + def test_empty_distribution(self, metric): + """Test that 
computing the FCD on empty distributions returns 0.""" + fcd = metric.compute() + assert fcd == 0.0 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.unit + def test_gpu_compatibility(self, metric, device, codec): + """Test that the metric works correctly on GPU.""" + assert metric.device.type == "cuda" + B, C, T = 3, codec.num_codebooks, 20 + codes = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + metric.update(codes, codes_len, is_real=True) + metric.update(codes, codes_len, is_real=False) + + fcd = metric.compute() + + eps = 0.01 + assert isinstance(fcd, torch.Tensor) + assert fcd.device.type == "cuda" + assert fcd < eps and fcd >= 0, f"FCD value is {fcd} but should be close to 0" + + @pytest.mark.unit + def test_update_from_audio_file(self, metric): + """Test the update_from_audio_file method.""" + + # Test with both "real" and "fake" audio files (different files) + metric.update_from_audio_file("tests/.data/tts/mini_ljspeech/wavs/LJ019-0373.wav", is_real=True) + metric.update_from_audio_file("tests/.data/tts/mini_ljspeech/wavs/LJ050-0234.wav", is_real=False) + + fcd = metric.compute() + assert isinstance(fcd, torch.Tensor) + assert fcd > 0, f"FCD value is {fcd} but should be positive given that we tested different audio files" + + @pytest.mark.unit + def test_empty_codes_update(self, metric, device): + """Test that the FCD metric doesn't crash when provided with empty codes.""" + B, C, T = 1, 0, 100 + codes = torch.ones(B, C, T, device=device) + codes_len = T * torch.ones(B, device=device) + # if it crashes PyTest will report it + metric.update(codes, codes_len, is_real=True) + + @pytest.mark.unit + def test_codebooks_mismatch_update(self, metric, device, codec): + """Test that the FCD metric doesn't crash when provided with incorrect number of codebooks.""" + B = 2 + C = codec.num_codebooks - 1 # intentionally missing one codebook + T = 10 + codes = torch.ones(B, C, T, device=device) + codes_len = T * torch.ones(B, device=device, dtype=torch.long) + # if it crashes PyTest will report it + metric.update(codes, codes_len, is_real=True) diff --git a/tests/collections/tts/modules/test_transformer_2501.py b/tests/collections/tts/modules/test_transformer_2501.py index b7f486028aea..606ce12bf324 100644 --- a/tests/collections/tts/modules/test_transformer_2501.py +++ b/tests/collections/tts/modules/test_transformer_2501.py @@ -37,6 +37,9 @@ def set_seed(seed): random.seed(seed) +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths + + @pytest.mark.unit class TestConvolutionLayer: @classmethod @@ -53,6 +56,7 @@ def setup_class(cls): [-1.0317, 1.6818, 1.4257, -0.5003, -1.7254, 0.8830, -0.4541, -0.4631, -0.0986, 0.5083], [-0.3231, -1.0899, 0.5774, 0.1661, 0.9620, -2.3307, -0.6158, -0.3663, 1.2469, -1.0208]]] ) + cls.input_mask = torch.ones(1, cls.input_tensor.shape[2]) # fmt:on def test_non_causal_forward(self): @@ -68,7 +72,7 @@ def test_non_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -96,7 +100,7 @@ def test_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -133,6 +137,7 @@ def setup_class(cls): [-0.1543, 0.3365, 
1.7475], [-0.1753, 0.4115, 0.0772]]] ) + cls.input_mask = torch.ones(1, cls.input_tensor.shape[1]) # fmt:on def test_causal_forward(self): @@ -142,7 +147,7 @@ def test_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -168,7 +173,7 @@ def test_non_causal_forward(self): ) with torch.no_grad(): - output_tensor = layer(self.input_tensor) + output_tensor = layer(self.input_tensor, self.input_mask) # fmt:off expected_output_tensor = torch.Tensor( @@ -795,3 +800,58 @@ def test_forward_causal_self_attn_and_has_xattn(self): expected_output["attn_probabilities"][i]["cross_attn_probabilities"][0], atol=1e-4, ) + + +@pytest.mark.unit +class TestTransformerBatchedInference: + @classmethod + def setup_class(cls): + cls.n_layers = 3 + cls.d_model = 4 + cls.d_ffn = 16 + cls.sa_n_heads = 2 + cls.p_dropout = 0.0 + cls.p_dropout_out = 0.0 + cls.max_length_causal_mask = 10 + cls.short_length = 4 + cls.long_length = 10 + + def test_forward(self): + set_seed(0) + query_tensor1 = torch.randn(1, self.long_length, self.d_model) + query_tensor2 = torch.randn(1, self.short_length, self.d_model) + + padding_tensor = torch.randn(1, self.long_length - self.short_length, self.d_model) + query_tensor2_padded = torch.cat([query_tensor2, padding_tensor], dim=1) + lengths = torch.tensor([self.long_length, self.short_length]) + mask_batched = get_mask_from_lengths(lengths) + + query_batched = torch.cat([query_tensor1, query_tensor2_padded], dim=0) + + mask_bs1_1 = torch.ones(1, self.long_length) + mask_bs1_2 = torch.ones(1, self.short_length) + + for is_causal in [True, False]: + for kernel_size in [1, 3]: + model = Transformer( + n_layers=self.n_layers, + d_model=self.d_model, + d_ffn=self.d_ffn, + sa_n_heads=self.sa_n_heads, + kernel_size=kernel_size, + p_dropout=self.p_dropout, + p_dropout_out=self.p_dropout_out, + is_causal=is_causal, + max_length_causal_mask=self.max_length_causal_mask, + ) + + output_batched = model(query_batched, mask_batched) + output_bs1_1 = model(query_tensor1, mask_bs1_1) + output_bs1_2 = model(query_tensor2, mask_bs1_2) + + assert torch.allclose( + output_batched['output'][0][: self.long_length, :], output_bs1_1['output'], atol=1e-4 + ) + assert torch.allclose( + output_batched['output'][1][: self.short_length, :], output_bs1_2['output'], atol=1e-4 + ) diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh new file mode 100644 index 000000000000..b1c7967fa273 --- /dev/null +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_DecoderContext.sh @@ -0,0 +1,33 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
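+
+# Fast dev run of the Magpie-TTS decoder-context recipe (config: magpietts_dc_en) on the AN4
+# context manifests, limited to a single train/val batch and run under coverage for CI.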
+coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ + --config-name magpietts_dc_en \ + +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train_context_v1.json" \ + +train_ds_meta.an4.audio_dir="/" \ + +train_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +train_ds_meta.an4.feature_dir=null \ + +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \ + +val_ds_meta.an4.audio_dir="/" \ + +val_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +val_ds_meta.an4.feature_dir=null \ + max_epochs=1 \ + batch_size=4 \ + model.codecmodel_path="/home/TestData/tts/21fps_causal_codecmodel.nemo" \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.strategy=auto \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.num_workers=0 \ + ~trainer.check_val_every_n_epoch diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh new file mode 100644 index 000000000000..861ad3fdb92d --- /dev/null +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder.sh @@ -0,0 +1,33 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ + --config-name magpietts_en \ + +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train_context_v1.json" \ + +train_ds_meta.an4.audio_dir="/" \ + +train_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +train_ds_meta.an4.feature_dir=null \ + +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \ + +val_ds_meta.an4.audio_dir="/" \ + +val_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +val_ds_meta.an4.feature_dir=null \ + max_epochs=1 \ + batch_size=4 \ + model.codecmodel_path="/home/TestData/tts/21fps_causal_codecmodel.nemo" \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.strategy=auto \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.num_workers=0 \ + ~trainer.check_val_every_n_epoch diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh new file mode 100644 index 000000000000..186cd2148e56 --- /dev/null +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_OnlinePO.sh @@ -0,0 +1,90 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts.py \ + --config-name magpietts_multilingual_v1 \ + +mode="onlinepo_train" \ + +model.text_tokenizers.english_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.english_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.spanish_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.spanish_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.french_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.french_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.dutch_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.dutch_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.italian_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.italian_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.german_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.german_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.portugese_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.portugese_chartokenizer.pretrained_model="google/byt5-small" \ + +model.text_tokenizers.polish_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.polish_chartokenizer.pretrained_model="google/byt5-small" \ + +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train_context_v1.json" \ + +train_ds_meta.an4.audio_dir="/" \ + +train_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +train_ds_meta.an4.feature_dir=null \ + +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \ + +val_ds_meta.an4.audio_dir="/" \ + +val_ds_meta.an4.tokenizer_names="[english_phoneme]" \ + +val_ds_meta.an4.feature_dir=null \ + +init_from_ptl_ckpt="/home/TestData/tts/2506_SeenSpeaker/T5TTS--val_loss\=0.3125-epoch\=8.ckpt" \ + max_epochs=1 \ + batch_size=2 \ + +model.grpo_beta=0.0 \ + +model.num_generations_per_item=6 \ + +model.reference_free=true \ + +model.inference_cfg_prob=0.0 \ + +model.inference_cfg_scale=2.5 \ + +model.cer_reward_weight=0.5 \ + +model.ssim_reward_weight=0.5 \ + +model.reward_asr_model="whisper" \ + model.local_transformer_type="none" \ + model.cfg_unconditional_prob=0.0 \ + model.model_type="multi_encoder_context_tts" \ + model.transcript_decoder_layers="[0,2,4,6,8,10]" \ + model.context_decoder_layers="[1,3,5,7,9,11]" \ + model.context_duration_min=3.0 \ + model.context_duration_max=8.0 \ + model.decoder.p_dropout=0.0 \ + model.context_encoder.p_dropout=0.0 \ + model.encoder.p_dropout=0.0 \ + model.decoder.kernel_size=1 \ + model.decoder.xa_n_heads=1 \ + model.context_encoder.n_layers=6 \ + model.encoder.is_causal=false \ + model.use_text_conditioning_encoder=true \ + +model.legacy_text_conditioning=True \ + +model.forced_num_all_tokens_per_codebook=2048 \ + +model.forced_audio_eos_id=2047 \ + +model.forced_audio_bos_id=2046 \ + +model.forced_context_audio_eos_id=2045 \ + +model.forced_context_audio_bos_id=2044 \ + 
model.codecmodel_path="/home/TestData/tts/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo" \ + model.alignment_loss_scale=0.0 \ + model.prior_scaling_factor=null \ + trainer.log_every_n_steps=10 \ + +model.inference_topk=2016 \ + model.optim.lr=2e-7 \ + ~model.optim.sched \ + +model.use_kv_cache_during_online_po=true \ + exp_manager.checkpoint_callback_params.monitor="val_cer_gt" \ + exp_manager.checkpoint_callback_params.mode="min" \ + trainer.precision=32 \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.strategy=auto \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.num_workers=0 \ + ~trainer.check_val_every_n_epoch diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh new file mode 100644 index 000000000000..d47e112ad2bb --- /dev/null +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh @@ -0,0 +1,30 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo scripts/magpietts/infer_and_evaluate.py \ + --codecmodel_path /home/TestData/tts/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo \ + --datasets an4_val_ci \ + --out_dir ./mp_ss_0 \ + --batch_size 4 \ + --use_cfg \ + --cfg_scale 2.5 \ + --num_repeats 1 \ + --temperature 0.6 \ + --hparams_files /home/TestData/tts/2506_SeenSpeaker/hparams.yaml \ + --checkpoint_files /home/TestData/tts/2506_SeenSpeaker/T5TTS--val_loss=0.3125-epoch=8.ckpt \ + --legacy_codebooks \ + --legacy_text_conditioning \ + --apply_attention_prior \ + --clean_up_disk \ + --cer_target 0.3 \ + --ssim_target 0.5 diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh new file mode 100644 index 000000000000..228f7e97fc35 --- /dev/null +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh @@ -0,0 +1,30 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
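+
+# Zero-shot inference/evaluation CI check: runs infer_and_evaluate.py with a legacy zero-shot
+# checkpoint (hence the --legacy_* flags) and fails if the resulting CER exceeds 0.1 or the
+# speaker similarity falls below 0.7.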
+coverage run --branch -a --data-file=/workspace/.coverage --source=/workspace/nemo scripts/magpietts/infer_and_evaluate.py \ + --codecmodel_path /home/TestData/tts/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo \ + --datasets an4_val_ci \ + --out_dir ./mp_zs_0 \ + --batch_size 4 \ + --use_cfg \ + --cfg_scale 2.5 \ + --num_repeats 1 \ + --temperature 0.6 \ + --hparams_files /home/TestData/tts/2506_ZeroShot/lrhm_short_yt_prioralways_alignement_0.002_priorscale_0.1.yaml \ + --checkpoint_files /home/TestData/tts/2506_ZeroShot/dpo-T5TTS--val_loss=0.4513-epoch=3.ckpt \ + --legacy_codebooks \ + --legacy_text_conditioning \ + --apply_attention_prior \ + --clean_up_disk \ + --cer_target 0.1 \ + --ssim_target 0.7