10 changes: 10 additions & 0 deletions .github/workflows/cicd-main-speech.yml
@@ -189,6 +189,16 @@ jobs:
          - runner: self-hosted-azure
            script: SPEECHLM_HF_Training_SALM
            timeout: 20
          - runner: self-hosted-azure
            script: L2_TTS_Fast_dev_runs_Magpietts_DecoderContext
          - runner: self-hosted-azure
            script: L2_TTS_Fast_dev_runs_Magpietts_MultiEncoder
          - runner: self-hosted-azure
            script: L2_TTS_Fast_dev_runs_Magpietts_OnlinePO
          - runner: self-hosted-azure
            script: L2_TTS_InferEvaluate_Magpietts_ZeroShot
          - runner: self-hosted-azure
            script: L2_TTS_InferEvaluate_Magpietts_SeenSpeakers
    needs: [unit-tests]
    runs-on: ${{ matrix.runner }}
    name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }}
26 changes: 26 additions & 0 deletions examples/tts/README_frame_stacking.md
@@ -0,0 +1,26 @@
# Overview
This PR introduces a frame-stacking implementation in Magpie-TTS. Frame-stacking is disabled by default. It can be enabled by setting `frame_stacking_factor` > 1 in the YAML config.

# Frame-stacking

## Overview
Frame-stacking is a technique that allows the Magpie-TTS **base decoder** (also known as the "main" or "first-stage" decoder) to **process multiple consecutive audio frames in a single forward pass**, leaving the job of generating the individual frames and codebooks to a second, smaller "Local Transformer" ("LT") decoder. The goal is to accelerate inference by reducing the number of generation steps of the base decoder. In this two-stage approach:

1. The base decoder processes multiple frames at once, producing a single latent representation for each group (stack) of frames.
2. The Local Transformer then generates the individual `frames * codebooks` tokens for each stack.

The Local Transformer is much faster than the base decoder, making this two-stage approach significantly faster than generating each frame with the base decoder. The speed improvement comes from two factors:
* **Fewer parameters**: The LT decoder is lightweight compared to the base decoder
* **Shorter sequences**: The LT decoder only attends to the current frame stack and the latent, not the entire frame sequence

The base decoder can also generate audio codes directly without an LT, but when frame-stacking is enabled, using the LT decoder is typically necessary to achieve high-quality synthesis.
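
To make the two-stage loop concrete, here is a minimal sketch in Python. The names (`base_decoder`, `local_transformer.generate`, `text_cond`) and the default values are hypothetical illustrations, not the actual Magpie-TTS API:

```python
import torch

def generate_with_frame_stacking(base_decoder, local_transformer, text_cond,
                                 num_steps, frame_stacking_factor=4, num_codebooks=8):
    """Sketch of two-stage decoding with frame-stacking (hypothetical API)."""
    generated = []  # one (frame_stacking_factor, num_codebooks) token tensor per step
    for _ in range(num_steps):  # each base-decoder step covers a whole stack of frames
        # The base decoder attends over the text condition and all previously
        # generated stacks, and emits a single latent for the next stack.
        latent = base_decoder(text_cond, generated)
        # The Local Transformer expands that latent into the individual tokens,
        # attending only to the latent and the tokens already emitted in this stack.
        stack_tokens = local_transformer.generate(
            latent, num_tokens=frame_stacking_factor * num_codebooks)
        generated.append(stack_tokens.view(frame_stacking_factor, num_codebooks))
    # Shape: (num_steps * frame_stacking_factor, num_codebooks) audio codes.
    return torch.cat(generated, dim=0)
```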

## Design and Implementation
* The `frame_stacking_factor` parameter controls the number of frames to stack. The default is 1, which means no frame-stacking. We have tested values up to `4`.
* For each codebook, we keep a separate embedding table for each frame within the stack. At the input to the decoder, the embeddings are averaged across codebooks (as usual) and also across frames within the stack. The embedding tables are shared between the base and LT decoders (see the sketch below).
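
Below is a minimal sketch of the embedding lookup and averaging described in the last bullet. The class name and shapes are hypothetical (codes assumed to arrive as `[batch, time, codebooks]`); it is an illustration, not the actual implementation:

```python
import torch
import torch.nn as nn

class StackedCodeEmbedding(nn.Module):
    """Hypothetical sketch: one embedding table per (frame-in-stack, codebook) pair."""

    def __init__(self, num_tokens, d_model, num_codebooks, frame_stacking_factor):
        super().__init__()
        self.tables = nn.ModuleList([
            nn.Embedding(num_tokens, d_model)
            for _ in range(frame_stacking_factor * num_codebooks)
        ])
        self.num_codebooks = num_codebooks
        self.factor = frame_stacking_factor

    def forward(self, codes):
        # codes: (batch, time, codebooks), with time divisible by the stacking factor.
        b, t, c = codes.shape
        codes = codes.reshape(b, t // self.factor, self.factor, c)
        embs = [
            self.tables[f * self.num_codebooks + cb](codes[:, :, f, cb])
            for f in range(self.factor)
            for cb in range(self.num_codebooks)
        ]
        # Average across both codebooks and frames within each stack.
        return torch.stack(embs, dim=0).mean(dim=0)  # (batch, time // factor, d_model)
```

In an actual model, the same tables would be passed to both the base and LT decoders so that the weights are shared, as noted above.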

## Limitations
This is still a work in progress. Specifically, the following are not yet implemented or tested:
* Online code extraction combined with frame-stacking.
* Alignment encoder with frame-stacking.
* CTC loss with frame-stacking.
95 changes: 95 additions & 0 deletions examples/tts/README_magpietts_legacy_checkpoints.md
@@ -0,0 +1,95 @@
# Background
Magpie-TTS uses special tokens such as AUDIO_BOS and AUDIO_EOS for its operation. The indices of these tokens come after the audio codec tokens, at the end of the embedding table.

In April 2025 we changed the layout of the embedding table in a non-backwards-compatible way:

## Old Layout (until April 16)
With the most common codec configuration (2016 codes), the layout used to look like this:
```
| Index | Token Description | Comments |
|---------|----------------------|-----------------------------------------------------------------------------------------------------------|
| [0] | Codec Token 0 | |
| [1] | Codec Token 1 | |
| [2] | Codec Token 2 | |
| ... | ... | |
| [2015] | Codec Token 2015 | |
| [2016] | <Unused> | |
| [2017] | <Unused> | |
| [2018] | <Unused> | |
| ... | | |
| [2044] | Context Audio BOS | if model_type == `decoder_context_tts` |
| [2045] | Context Audio EOS | if model_type == `decoder_context_tts` |
| [2046] | Audio BOS | also used for Context Audio BOS if model_type == `multi_encoder_context_tts` or `single_encoder_sv_tts` |
| [2047] | Audio EOS | also used for Context Audio EOS if model_type == `multi_encoder_context_tts` or `single_encoder_sv_tts` |
```

## New Layout
The new layout for the same codec configuration is:
```
| Index | Token Description | Comments |
|---------|----------------------|----------|
| [0] | Codec Token 0 | |
| [1] | Codec Token 1 | |
| [2] | Codec Token 2 | |
| ... | ... | |
| [2015] | Codec Token 2015 | |
| [2016] | Audio BOS | |
| [2017] | Audio EOS | |
| [2018] | Context Audio BOS | |
| [2019] | Context Audio EOS | |
| [2020] | MASK token (MaskGit) | |
| [2021] | RESERVED_1 | |
| [2022] | RESERVED_2 | |
| [2023] | RESERVED_3 | |
```

# How to Train and Load a New Checkpoint
For new training runs and inference, all configuration is automatic:
* The number of codebooks, the codec codebook size, and the codec downsampling rate are all read from the codec checkpoint rather than configured in Magpie.
* The embedding table size is automatically set to codec_codebook_size + number_of_special_tokens (currently 2016 + 8 = 2024). There is no risk of accidentally stepping on codec tokens, since the table is automatically sized with enough room for the special tokens (see the sketch below).
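
As a small illustration of that arithmetic, here is a hedged sketch (a hypothetical helper, not part of the codebase) that reproduces the new-layout indices from the table above:

```python
# Sketch of the new-layout index arithmetic for a 2016-code codec
# (hypothetical helper; values mirror the table above).
def new_layout_special_tokens(codec_codebook_size=2016, num_special_tokens=8):
    names = ["audio_bos", "audio_eos", "context_audio_bos", "context_audio_eos",
             "mask", "reserved_1", "reserved_2", "reserved_3"]
    table_size = codec_codebook_size + num_special_tokens   # 2016 + 8 = 2024
    ids = {name: codec_codebook_size + i for i, name in enumerate(names)}
    return table_size, ids

# new_layout_special_tokens() -> (2024, {"audio_bos": 2016, "audio_eos": 2017, ...})
```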

# How to Load Old Checkpoints
For checkpoints created before the change, you can force the legacy codebook layout in one of these ways:

## If using `infer_and_evaluate.py`
Set the `--legacy_codebooks` command line option. There is no need to update your YAML file; the script automatically adds the overrides.

## If using a Hydra command line
This scenario arises when finetuning from an old checkpoint or using an old checkpoint for data generation.

You have two options:
### Add these to your command line
```
# decoder context model
+model.forced_num_all_tokens_per_codebook=2048 +model.forced_audio_eos_id=2047 +model.forced_audio_bos_id=2046 +model.forced_context_audio_eos_id=2045 +model.forced_context_audio_bos_id=2044

# multi encoder context and any other model type
+model.forced_num_all_tokens_per_codebook=2048 +model.forced_audio_eos_id=2047 +model.forced_audio_bos_id=2046 +model.forced_context_audio_eos_id=2047 +model.forced_context_audio_bos_id=2046
```
### Or, add these overrides to your YAML file
```
forced_num_all_tokens_per_codebook: 2048
forced_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -1} # 2047
forced_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} # 2046

# Depending on the old model type, the context_audio_bos_id and context_audio_eos_id will be different (choose one of the pairs below)

# For `multi_encoder_context_tts`, `single_encoder_sv_tts`:
#forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -1} # 2047
#forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -2} # 2046

# For `decoder_context_tts` models:
#forced_context_audio_eos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -3} # 2045
#forced_context_audio_bos_id: ${sum:${model.forced_num_all_tokens_per_codebook}, -4} # 2044
```

# Additional Details
Over the last few weeks we have gone through a few embedding table layouts. When using an old checkpoint, it's important to know which layout your checkpoint was trained with and to configure the system accordingly (the sketch at the end of this section summarizes the indices).

* Layout 1: used until April 16 (described in the table above). Add `--legacy-codebooks` to the `infer_and_evaluate.py` command line to run inference using this layout.

* Layout 2: after the [config changes](https://github.com/blisc/NeMo/commit/7e2cdca74a866ecefdbe01c0076ad9b5d140ac61): 2018 tokens with special tokens at the end (2017, 2016, 2015, 2014; the last two overwrite codec tokens). This is an invalid layout and these checkpoints should not be used.

* Layout 3: after the [bugfix](https://github.com/blisc/NeMo/commit/23e299a0bd14b666543b4bbcc7783f783acb0bd3) but before the [refactoring](https://github.com/blisc/NeMo/commit/8ba55061a0ebb161abff4b329e402d5307f4af98): 2024 tokens with special tokens at the end (2023, 2022, 2021, 2020). There are no automatic options provided for using this layout, but it can be manually configured by updating the `hparams.yaml` file with the `forced_*` options: set `forced_num_all_tokens_per_codebook` to `2024` and set the rest of the overrides as defined under the section `Or, add these overrides to your YAML file` above.

* Layout 4: The new layout, [from this commit onwards](https://github.com/blisc/NeMo/commit/8ba55061a0ebb161abff4b329e402d5307f4af98): 2024 tokens but with special tokens immediately after codec tokens (2016, 2017, 2018, 2019). Training and inference with the latest version of the code automatically use this layout.
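
For quick reference, the sketch below summarizes the special-token indices described above as a Python dictionary. It is illustrative only and not used by the codebase; Layout 2 is invalid and omitted, and for Layouts 1 and 3 the context BOS/EOS values shown are the `decoder_context_tts` ones (the `multi_encoder_context_tts` / `single_encoder_sv_tts` variants reuse the regular audio BOS/EOS indices instead):

```python
# Hedged summary of the embedding table layouts above for a 2016-code codec.
MAGPIE_LAYOUTS = {
    "layout_1": {  # until April 16; enable via the legacy-codebooks option or forced_* overrides
        "num_all_tokens_per_codebook": 2048,
        "audio_bos": 2046, "audio_eos": 2047,
        "context_audio_bos": 2044, "context_audio_eos": 2045,
    },
    "layout_3": {  # post-bugfix, pre-refactoring; configure manually via forced_* options
        "num_all_tokens_per_codebook": 2024,
        "audio_bos": 2022, "audio_eos": 2023,
        "context_audio_bos": 2020, "context_audio_eos": 2021,
    },
    "layout_4": {  # current layout; picked up automatically by training and inference
        "num_all_tokens_per_codebook": 2024,
        "audio_bos": 2016, "audio_eos": 2017,
        "context_audio_bos": 2018, "context_audio_eos": 2019,
        "maskgit_mask": 2020,
    },
}
```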
199 changes: 199 additions & 0 deletions examples/tts/conf/magpietts/magpietts.yaml
@@ -0,0 +1,199 @@
name: Magpie-TTS

max_epochs: ???
# Adjust batch size based on GPU memory
batch_size: 16
# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
# If null, then weighted sampling is disabled.
weighted_sampling_steps_per_epoch: null

# Dataset metadata for each manifest
# See DatasetMeta in https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/tts/data/text_to_speech_dataset.py
train_ds_meta: ???
val_ds_meta: ???

model:
  model_type: "decoder_ce" # decoder_context_tts or decoder_ce
  use_text_conditioning_encoder: true # Enable or disable text context. Audio context is always enabled.
  text_conditioning_tokenizer_name: text_ce_tokenizer # The tokenizer to be used for text contexts
  context_duration_min: 5.0
  context_duration_max: 5.0
  load_cached_codes_if_available: true
  prior_scaling_factor: 0.5
  prior_end_step: 12000
  prior_scaledown_start_step: 8000
  indefinite_prior_prob: 0.0 # If > 0, then the prior will be applied after prior_end_step with this probability.
  alignment_loss_scale: 0.002
  embedding_dim: 768
  codecmodel_path: ???
  max_epochs: ${max_epochs}
  steps_per_epoch: ${weighted_sampling_steps_per_epoch}
  cfg_unconditional_prob: 0.1

  # Alignment encoder parameters, to binarize the prior.
  # This is used for attention-constrained training and inference.
  use_alignment_encoder: false
  # Below args are only relevant if use_alignment_encoder is true.
  use_prior_for_aligner: true # Whether to use the beta-binomial prior to train the alignment encoder
  alignment_encoder_loss_scale: 1.0
  binarize_prior_after_step: 10000 # Switch from beta-binomial prior to binarized prior after this step.
  binarize_attn_method: "nemo_binarize" # nemo_binarize or argmax.
  prior_future_context: 2 # Future window of the binarized prior.
  prior_past_context: 2 # Past window of the binarized prior.
  prior_future_decay: 0.8 # Decay factor for future context.
  prior_past_decay: 0.5 # Decay factor for past context.
  binarize_repeat_audio_factor: 2 # Temporally repeat audio timesteps so nemo_binarize works better. Increase this for low-frame-rate codecs.
  binarized_prior_epsilon: 0.0
  aligner_encoder_train_steps: 50000

  # Local transformer parameters for autoregressive codebook prediction within a frame
  local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit"
  # Below args are only relevant if local_transformer_type is not "none".
  local_transformer_loss_scale: 1.0
  local_transformer_n_layers: 1
  local_transformer_n_heads: 1
  local_transformer_hidden_dim: 256

  text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts.
  text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context.

  text_tokenizers:
    english_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      punct: true
      apostrophe: true
      pad_with_space: false
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
        heteronyms: "scripts/tts_dataset_files/heteronyms-052722"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
    text_ce_tokenizer: # Used for text context
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    ### For additional languages, consider adding a generic byt5 tokenizer like the one below
    # french_chartokenizer: # Used for text context
    #   _target_: AutoTokenizer
    #   pretrained_model: "google/byt5-small"

  train_ds:
    dataset:
      _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset
      dataset_meta: ${train_ds_meta}
      weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
      min_duration: 0.2
      max_duration: 20.0

    dataloader_params:
      batch_size: ${batch_size}
      num_workers: 4
      drop_last: true
      pin_memory: true

  validation_ds:
    dataset:
      _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset
      dataset_meta: ${val_ds_meta}
      min_duration: 0.2
      max_duration: 20.0

    dataloader_params:
      batch_size: ${batch_size}
      num_workers: 4
      pin_memory: true

  encoder:
    n_layers: 6
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 3
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: false
    is_causal: true
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true

  context_encoder: # Only used for decoder_ce (and multi_encoder_context_tts), ignored otherwise
    n_layers: 1
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 3
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: false
    is_causal: false
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true

  decoder:
    n_layers: 12
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 1
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: true
    xa_d_head: 128
    xa_d_memory: 768
    xa_n_heads: 1
    is_causal: true
    apply_norm_to_cond: true
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true
    make_prior_window_strict: true

  optim:
    _target_: torch.optim.AdamW
    lr: 2e-4

    sched:
      name: ExponentialLR
      gamma: 0.998

trainer:
  num_nodes: 1
  devices: -1
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
  precision: 32
  max_epochs: ${max_epochs}
  accumulate_grad_batches: 1
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 0
  benchmark: false
  gradient_clip_val: 2.5

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_wandb_logger: false
  wandb_logger_kwargs:
    entity: null
    project: null
    group: null
    name: ${name}
    resume: true # Enable resume so that training log metrics are merged into the previous wandb run id.
  create_checkpoint_callback: true
  checkpoint_callback_params:
    monitor: val_loss
    mode: min
    save_top_k: 5
    save_best_model: true
    always_save_nemo: true
    filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}'
  resume_if_exists: true
  resume_ignore_no_checkpoint: true