Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TDT model pull request #6536

Merged
merged 59 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
a6269d7
TDT model pull request, initial draft
May 2, 2023
952d261
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
a6893a6
TDT PR WIP
May 4, 2023
9107407
TDT PR WIP
May 4, 2023
657ed9d
TDT PR WIP
May 4, 2023
8b01b42
TDT PR WIP
May 4, 2023
51eed21
TDT WIP
May 5, 2023
452f45b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 5, 2023
e3947e1
TDT WIP
May 5, 2023
c4764ea
Merge branch 'TDT_PR_2' of https://github.com/hainan-xv/NeMo into TDT…
May 5, 2023
fd9c8ac
TDT WIP
May 5, 2023
23c5759
TDT WIP
May 5, 2023
c5032b9
TDT WIP
May 5, 2023
4053e3d
TDT WIP
May 5, 2023
23251fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 5, 2023
06175ae
TDT WIP
May 5, 2023
79e61b7
Merge branch 'TDT_PR_2' of https://github.com/hainan-xv/NeMo into TDT…
May 5, 2023
26d6307
TDT WIP
May 5, 2023
656ea9e
TDT WIP
May 5, 2023
6d7b172
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 5, 2023
ffb9502
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 9, 2023
1e55f92
TDT WIP
May 9, 2023
31100de
TDT WIP
May 10, 2023
59a2a52
addressed some review comments, part1
May 11, 2023
fe52c95
addressed some review comments, part1, one line fix
May 11, 2023
4b00365
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2023
462f69e
add tests for comparing TDT alphas with pytorch VS kernel computation
May 17, 2023
69f9466
Merge branch 'TDT_PR_2' of https://github.com/hainan-xv/NeMo into TDT…
May 17, 2023
45e025c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2023
7d29d44
add tests for comparing multiblank alphas with pytorch VS kernel comp…
May 17, 2023
aff4318
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2023
cc52b56
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 17, 2023
2483b52
Merge branch 'TDT_PR_2' of https://github.com/hainan-xv/NeMo into TDT…
May 17, 2023
68ca68c
add tests for fixed case computation for TDT
May 17, 2023
55bfcee
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 17, 2023
8fe52cf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2023
7d41d36
add more comments for greedy-batch decoding for TDT
May 17, 2023
92424b2
add more comments for greedy-batch decoding for TDT
May 17, 2023
fc0d3c3
include config for TDT model with stateless decoders
May 17, 2023
709ab0e
add reference to TDT in Readme
May 17, 2023
c2b23c0
slight modification of config file comments
May 19, 2023
64c3c72
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 19, 2023
bd1388c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 23, 2023
dcf4cbb
addressed more comments
May 23, 2023
ab47b88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2023
72dd625
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 25, 2023
ccfcbc8
Merge branch 'main' of https://github.com/NVIDIA/NeMo into TDT_PR_2
May 25, 2023
d7b7307
more detailed comments for tdt kernel
May 25, 2023
0cfcd06
Merge branch 'TDT_PR_2' of https://github.com/hainan-xv/NeMo into TDT…
May 25, 2023
9167032
one line fix
Jun 1, 2023
bf7b234
merge with head
Jun 1, 2023
5dd1115
Merge branch 'main' into TDT_PR_2
hainan-xv Jun 2, 2023
fb5f73e
fixed small bug that results in test fails for rnnt_decoding
Jun 2, 2023
3d81ff1
fixed small bug that results in test fails for rnnt_decoding
Jun 2, 2023
4eb8bfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
9b8f057
fixed small bug that results in test fails for rnnt_decoding
Jun 2, 2023
caef0c5
Merge branch 'TDT_PR_2' of https://github.com/hainan-xv/NeMo into TDT…
Jun 2, 2023
3a5db33
Merge branch 'main' into TDT_PR_2
hainan-xv Jun 2, 2023
c85d143
remove unused import
Jun 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Key Features
* CTC
* Transducer/RNNT
* Hybrid Transducer/CTC
* NeMo Original `Multi-blank Transducers <https://arxiv.org/abs/2211.03541>`_
* NeMo Original `Multi-blank Transducers <https://arxiv.org/abs/2211.03541>`_ and `Token-and-Duration Transducers (TDT) <https://arxiv.org/abs/2304.06795>`_
* Streaming/Buffered ASR (CTC/Transducer) - `Chunked Inference Examples <https://github.com/NVIDIA/NeMo/tree/stable/examples/asr/asr_chunked_inference>`_
* Cache-aware Streaming Conformer - `<https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer>`_
* Beam Search decoding
Expand Down
279 changes: 279 additions & 0 deletions examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
# This file contains the default values for training a Conformer-TDT ASR model, large size (~120M) with sub-word encoding.

# You can find detailed info about TDT models at https://arxiv.org/abs/2304.06795.
titu1994 marked this conversation as resolved.
Show resolved Hide resolved

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file.

# Note: the added duration outputs from the joiner make TDT models slightly larger than corresponding conventional RNN-T models,
# although the difference is tiny -- the added number of params is roughly num-durations X (joint_hidden + pred_hidden), typically in the
# order of thousands of params. This is negligible even with the "Small" config with around 14 million params.
# Recommended duraction config is [0, 1, 2, ... , n] where optimal n is usually between 4 and 8 depending on the dataset.

# +--------------+---------+---------+----------+------------------+--------------+--------------------------+-----------------+
# | Model | d_model | n_heads | n_layers | conv_kernel_size | weight_decay | pred_hidden/joint_hidden | pred_rnn_layers |
# +==============+=========+========+===========+==================+==============+==========================+=================+
# | Small (14M)| 176 | 4 | 16 | 31 | 0.0 | 320 | 1 |
titu1994 marked this conversation as resolved.
Show resolved Hide resolved
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+
# | Medium (32M)| 256 | 4 | 16 | 31 | 1e-3 | 640 | 1 |
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+
# | Large (120M)| 512 | 8 | 17 | 31 | 1e-3 | 640 | 1 |
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+
# | XLarge (644M)| 1024 | 8 | 24 | 5 | 1e-3 | 640 | 2 |
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+

# Default learning parameters in this config are set for global batch size of 2K while you may use lower values.
# To increase the global batch size with limited number of GPUs, you may use higher accumulate_grad_batches.
# However accumulate_grad_batches is better to be avoided as long as the global batch size is large enough and training is stable.

name: "Conformer-TDT-BPE"

model:
sample_rate: 16000
compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
log_prediction: true # enables logging sample predictions in the output during training
skip_nan_grad: false

model_defaults:
titu1994 marked this conversation as resolved.
Show resolved Hide resolved
enc_hidden: ${model.encoder.d_model}
pred_hidden: 640
joint_hidden: 640

# variables for TDT configs.
tdt_durations: [0, 1, 2, 3, 4]
num_tdt_durations: 5


train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: true
use_start_end_token: false
trim_silence: false
max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
min_duration: 0.1
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false

test_ds:
manifest_filepath: null
sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false

# You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
frame_splicing: 1
dither: 0.00001
pad_to: 0

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 17
d_model: 512

# Sub-sampling params
subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding
subsampling_factor: 4 # must be power of 2 for striding and vggnet
subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model
causal_downsampling: false

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
att_context_style: regular # regular or chunked_limited
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 31
conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
# conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
# null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

decoder:
titu1994 marked this conversation as resolved.
Show resolved Hide resolved
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null # Currently only null is supported for export.
random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.

prednet:
pred_hidden: ${model.model_defaults.pred_hidden}
pred_rnn_layers: 1
t_max: null
dropout: 0.2

joint:
_target_: nemo.collections.asr.modules.RNNTJoint
log_softmax: null # 'null' would set it automatically according to CPU/GPU device
preserve_memory: false # dramatically slows down training, but might preserve some memory

# Fuses the computation of prediction net + joint net + loss + WER calculation
# to be run on sub-batches of size `fused_batch_size`.
# When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
# `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
# Using small values here will preserve a lot of memory during training, but will make training slower as well.
# An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
# However, to preserve memory, this ratio can be 1:8 or even 1:16.
# Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
fuse_loss_wer: true
fused_batch_size: 16

jointnet:
joint_hidden: ${model.model_defaults.joint_hidden}
activation: "relu"
dropout: 0.2
num_extra_outputs: ${model.model_defaults.num_tdt_durations}

decoding:
# Using greedy decoding is highly recommended for TDT models. Using greedy-batch will give very bad results
# if omega is 0; even if omega is non-zero, greedy-batch results are still going to be inaccurate.
strategy: "greedy"

# this must not be None in order to use the TDT specific decoding method.
durations: ${model.model_defaults.tdt_durations}

# greedy strategy config
greedy:
max_symbols: 10

# beam strategy config
beam:
beam_size: 2
return_best_hypothesis: False
score_norm: true
tsd_max_sym_exp: 50 # for Time Synchronous Decoding
alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding

loss:
# This is the main different between a TDT model and a conventional RNNT model -- the loss function.
loss_name: "tdt"

tdt_kwargs:
# FastEmit regularization: https://arxiv.org/abs/2010.11148
# You may enable FastEmit to reduce the latency of the model for streaming
fastemit_lambda: 0.001 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

# refer to https://arxiv.org/abs/2304.06795 for the meaning of the following three configs.
durations: ${model.model_defaults.tdt_durations}
sigma: 0.05 # hyper-param for under-normalization.
omega: 0.1 # weight for regular RNN-T loss.

# Adds Gaussian noise to the gradients of the decoder to avoid overfitting
variational_noise:
start_step: 0
std: 0.0

optim:
name: adamw
lr: 5.0
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: NoamAnnealing
d_model: ${model.encoder.d_model}
# scheduler config override
warmup_steps: 10000
warmup_ratio: null
min_lr: 1e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 500
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training


exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_wer"
mode: "min"
save_top_k: 5
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints
resume_if_exists: false
resume_ignore_no_checkpoint: false

create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null

Loading
Loading