Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add conformer configs for hat model #6372

Merged
merged 14 commits into from
Apr 14, 2023
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Key Features
* Speech processing
* `HuggingFace Space for Audio Transcription (File, Microphone and YouTube) <https://huggingface.co/spaces/smajumdar/nemo_multilingual_language_id>`_
* `Automatic Speech Recognition (ASR) <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/intro.html>`_
* Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, Squeezeformer-CTC, Squeezeformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, FastConformer-CTC, FastConformer-Transducer...
* Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, Squeezeformer-CTC, Squeezeformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, FastConformer-CTC, FastConformer-Transducer, Conformer-HAT...
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you please also update the following statement to have Hybrid ASR:
Supports CTC, Transducer/RNNT and Hybrid losses/decoders

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

* Supports CTC and Transducer/RNNT losses/decoders
* NeMo Original `Multi-blank Transducers <https://arxiv.org/abs/2211.03541>`_
* Beam Search decoding
Expand Down
25 changes: 25 additions & 0 deletions docs/source/asr/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,31 @@ You may find the example config files of Conformer variant of such hybrid models
with sub-word encoding at ``<NeMo_git_root>/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml``.


.. _Conformer-HAT_model:

Conformer-HAT (Hybrid Autoregressive Transducer)
--------------------------------------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lines should have the same size as the title.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Conformer HAT model (do not confuse it with Hybrid-Transducer-CTC) is a modification of Conformer-Transducer model based on `Google paper <https://arxiv.org/abs/2003.07705>`_.
The main idea is to separate labels and blank score predictions, which allows to estimate the internal LM probabilities during decoding.
When external LM is available for inference, the internal LM can be subtracted from HAT model prediction in beamsearch decoding to improve external LM efficiency.
It can be helpful in the case of text-only adaptation for new domains.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can users use this feature?
Do the current LM scripts support it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default Conformer HAT model works in decoding time as a standard Transducer model with the same interface. However, if you have an external ngram LM you can use scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py script. The new updated version of the script is under reviewing -- #6370

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VahidooX -- could you approve the PR if everything is OK?


The only difference from the standard Conformer-Transducer model (RNNT) is the use of `"HATJiont" <https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/modules/hybrid_autoregressive_transducer.py#L39>`_
class (instead of "RNNTJoint") for joint module. The all HAT logic is implemented in the "HATJiont" class.

.. image:: images/hat.png
:align: center
:alt: HAT Model
:scale: 50%

You may find the example config files of Conformer-HAT model with character-based encoding at
``<NeMo_git_root>/examples/asr/conf/conformer/hat/conformer_hat_char.yaml`` and
with sub-word encoding at ``<NeMo_git_root>/examples/asr/conf/conformer/hat/conformer_hat_bpe.yaml``.

By default, the decoding for HAT model works in the same way as for Conformer-Transducer.
In the case of external ngram LM fusion you can use ``<NeMo_git_root>/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py``.
To enable HAT internal LM subtraction set ``hat_subtract_ilm=True`` and find more appropriate couple of ``beam_alpha`` and ``hat_ilm_weight`` values in terms of the best recognition accuracy.

References
----------

Expand Down
267 changes: 267 additions & 0 deletions examples/asr/conf/conformer/hat/conformer_hat_bpe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# It contains the default values for training a Conformer-HAT (Hybrid Autoregressive Transducer - https://arxiv.org/abs/2003.07705) ASR model,
# large size (~120M) with Transducer loss and sub-word encoding.
# The only difference from the standard Conformer-Transducer model (RNNT) is the use of "HATJiont" class (instead of "RNNTJoint") for joint module.

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of Conformer-HAT, other parameters are the same as in this config file.
#
# +--------------+---------+---------+----------+------------------+--------------+--------------------------+-----------------+
# | Model | d_model | n_heads | n_layers | conv_kernel_size | weight_decay | pred_hidden/joint_hidden | pred_rnn_layers |
# +==============+=========+========+===========+==================+==============+==========================+=================+
# | Small (14M)| 176 | 4 | 16 | 31 | 0.0 | 320 | 1 |
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+
# | Medium (32M)| 256 | 4 | 16 | 31 | 1e-3 | 640 | 1 |
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+
# | Large (120M)| 512 | 8 | 17 | 31 | 1e-3 | 640 | 1 |
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+
# | XLarge (644M)| 1024 | 8 | 24 | 5 | 1e-3 | 640 | 2 |
# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+
#
# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-hat-hybrid-autoregressive-transducer

name: "Conformer-HAT-BPE"

model:
sample_rate: 16000
compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
log_prediction: true # enables logging sample predictions in the output during training
skip_nan_grad: false

model_defaults:
enc_hidden: ${model.encoder.d_model}
pred_hidden: 640
joint_hidden: 640

train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: true
use_start_end_token: false
trim_silence: false
max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
min_duration: 0.1
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false

test_ds:
manifest_filepath: null
sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false

# You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
frame_splicing: 1
dither: 0.00001
pad_to: 0

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 17
d_model: 512

# Sub-sampling parameters
subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding
subsampling_factor: 4 # must be power of 2 for striding and vggnet
subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model
causal_downsampling: false

# Reduction parameters: Can be used to add another subsampling layer at a given position.
# Having a 2x reduction will speedup the training and inference speech while keeping similar WER.
# Adding it at the end will give the best WER while adding it at the beginning will give the best speedup.
reduction: null # pooling, striding, or null
reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder
reduction_factor: 1

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
att_context_style: regular # regular or chunked_limited
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 31
conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
# conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
# null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1 # The dropout used before the encoder
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 1

decoder:
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null # Currently only null is supported for export.
random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.

prednet:
pred_hidden: ${model.model_defaults.pred_hidden}
pred_rnn_layers: 1
t_max: null
dropout: 0.2

joint:
_target_: nemo.collections.asr.modules.HATJoint # the only difference from the standard RNNT model
log_softmax: null # 'null' would set it automatically according to CPU/GPU device
preserve_memory: false # dramatically slows down training, but might preserve some memory

# Fuses the computation of prediction net + joint net + loss + WER calculation
# to be run on sub-batches of size `fused_batch_size`.
# When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
# `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
# Using small values here will preserve a lot of memory during training, but will make training slower as well.
# An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
# However, to preserve memory, this ratio can be 1:8 or even 1:16.
# Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
fuse_loss_wer: true
fused_batch_size: 16

jointnet:
joint_hidden: ${model.model_defaults.joint_hidden}
activation: "relu"
dropout: 0.2

decoding:
strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.

# greedy strategy config
greedy:
max_symbols: 10

# beam strategy config
beam:
beam_size: 2
return_best_hypothesis: False
score_norm: true
tsd_max_sym_exp: 50 # for Time Synchronous Decoding
alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding

loss:
loss_name: "default"

warprnnt_numba_kwargs:
# FastEmit regularization: https://arxiv.org/abs/2010.11148
# You may enable FastEmit to reduce the latency of the model for streaming
fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

# Adds Gaussian noise to the gradients of the decoder to avoid overfitting
variational_noise:
start_step: 0
std: 0.0

optim:
name: adamw
lr: 5.0
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: NoamAnnealing
d_model: ${model.encoder.d_model}
# scheduler config override
warmup_steps: 10000
warmup_ratio: null
min_lr: 1e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 500
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training


exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_wer"
mode: "min"
save_top_k: 5
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints
resume_if_exists: false
resume_ignore_no_checkpoint: false

create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
Loading