fast conformer configs and doc (#5970) (#5997)
* fast conformer configs and doc

* feedback

* adding fast conformer to main README

* path changes

* rewording

* further doc changes

* naming
---------

Signed-off-by: Dima Rekesh <[email protected]>
Co-authored-by: Dima Rekesh <[email protected]>
Co-authored-by: Dima Rekesh <[email protected]>
3 people authored Feb 12, 2023
1 parent 349b095 commit 8267971
Showing 4 changed files with 460 additions and 1 deletion.
README.rst (2 changes: 1 addition & 1 deletion)
@@ -63,7 +63,7 @@ Key Features
* Speech processing
* `HuggingFace Space for Audio Transcription (File, Microphone and YouTube) <https://huggingface.co/spaces/smajumdar/nemo_multilingual_language_id>`_
* `Automatic Speech Recognition (ASR) <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/intro.html>`_
- * Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, Squeezeformer-CTC, Squeezeformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, ...
+ * Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, Squeezeformer-CTC, Squeezeformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, FastConformer-CTC, FastConformer-Transducer, ...
* Supports CTC and Transducer/RNNT losses/decoders
* NeMo Original `Multi-blank Transducers <https://arxiv.org/abs/2211.03541>`_
* Beam Search decoding
docs/source/asr/models.rst (20 changes: 20 additions & 0 deletions)
@@ -128,6 +128,26 @@ You may find the example config files of Conformer-Transducer model with charact
``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_char.yaml`` and
with sub-word encoding at ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_bpe.yaml``.

Fast-Conformer
--------------

The Fast Conformer (CTC and RNNT) models use a faster version of the Conformer encoder, which differs from the original as follows:

* 8x depthwise convolutional subsampling with 256 channels
* A reduced convolution kernel size of 9 in the Conformer blocks

The Fast Conformer encoder is about 2.4x faster than the regular Conformer encoder without significant degradation in model quality.
Using 128 subsampling channels yields a 2.7x speedup over the baseline, but model quality starts to degrade.
With local attention, inference is possible on audio longer than 1 hour (with 256 subsampling channels) or longer than 2 hours (with 128 channels).

Fast Conformer models were trained using CosineAnnealing (instead of Noam) as the scheduler.

You may find the example CTC config at
``<NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml`` and
the Transducer config at ``<NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml``.

Note that both configs are subword-based (BPE).
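
As a minimal sketch of how the CTC config might be used for training (assuming the
standard NeMo and PyTorch Lightning entry points; the manifest and tokenizer paths
below are placeholders):

.. code-block:: python

    from omegaconf import OmegaConf
    import pytorch_lightning as pl
    from nemo.collections.asr.models import EncDecCTCModelBPE

    # Load the example config and fill in the required (???) fields.
    cfg = OmegaConf.load("examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml")
    cfg.model.train_ds.manifest_filepath = "train_manifest.json"     # placeholder
    cfg.model.validation_ds.manifest_filepath = "dev_manifest.json"  # placeholder
    cfg.model.tokenizer.dir = "tokenizer_dir"                        # placeholder

    # Build the trainer and model from the config, then train.
    trainer = pl.Trainer(**cfg.trainer)
    model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)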

Cache-aware Streaming Conformer
-------------------------------

examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml (187 changes: 187 additions & 0 deletions)
@@ -0,0 +1,187 @@
# This config contains the default values for training a large (~120M) Fast Conformer-CTC ASR model with sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for an effective batch size of 2K. To train with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use a higher accumulate_grad_batches.
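# As a rough illustration: effective batch size = batch_size * num_devices * num_nodes * accumulate_grad_batches,
# e.g. 16 * 8 * 1 * 16 = 2048 (illustrative numbers, not part of the original config).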

# You may find more info about Fast Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer

name: "FastConformer-CTC-BPE"

model:
  sample_rate: 16000
  log_prediction: true # enables logging sample predictions in the output during training
  ctc_reduction: 'mean_volume'
  skip_nan_grad: false

  train_ds:
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 16 # you may increase batch_size if your memory allows
    shuffle: true
    num_workers: 8
    pin_memory: true
    use_start_end_token: false
    trim_silence: false
    max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
    min_duration: 0.1
    # tarred datasets
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    # bucketing params
    bucketing_strategy: "fully_randomized"
    bucketing_batch_size: null

  validation_ds:
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 16 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  test_ds:
    manifest_filepath: null
    sample_rate: ${model.sample_rate}
    batch_size: 16 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  # recommended vocab size: 128 or 256 when training on ~1k-hour datasets, and 1k when training on 10k+ hour datasets
  # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
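  # Illustrative invocation (assumed flags; check the script's --help for the exact interface):
  #   python scripts/tokenizers/process_asr_text_tokenizer.py --manifest=<train_manifest.json> \
  #     --data_root=<tokenizer_out_dir> --vocab_size=1024 --tokenizer=spe --spe_type=bpe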
  tokenizer:
    dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: ${model.sample_rate}
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    log: true
    frame_splicing: 1
    dither: 0.00001
    pad_to: 0
    pad_value: 0.0

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    # you may use lower time_masks for smaller models to have a faster convergence
    time_masks: 10 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1 # you may set it if you need a different output size than the default d_model
    n_layers: 18
    d_model: 512

    # Sub-sampling params
    subsampling: dw_striding # vggnet, striding, stacking, stacking_norm, or dw_striding
    subsampling_factor: 8 # must be a power of 2 for striding and vggnet
    subsampling_conv_channels: 256 # -1 sets it to d_model
    causal_downsampling: false
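    # With window_stride=0.01 s and 8x subsampling, each encoder output frame
    # covers 80 ms of audio (the regular Conformer's 4x subsampling gives 40 ms).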

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
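    # e.g. att_context_size: [128, 128] (illustrative value) restricts self-attention to a
    # local window, which enables the long-audio inference described in the Fast Conformer docs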
    att_context_style: regular # regular or chunked_limited
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 9
    conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
    # conv_context_size can be "causal" or a list of two integers with conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
    # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
    conv_context_size: null

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer Modules
    dropout_pre_encoder: 0.1 # The dropout used before the encoder
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

  decoder:
    _target_: nemo.collections.asr.modules.ConvASRDecoder
    feat_in: null
    num_classes: -1
    vocabulary: []

  optim:
    name: adamw
    lr: 1e-3
    # optimizer arguments
    betas: [0.9, 0.98]
    # less necessity for weight_decay as we already have large augmentations with SpecAug
    # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
    # weight decay of 0.0 with lr of 2.0 also works fine
    weight_decay: 1e-3

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 15000
      warmup_ratio: null
      min_lr: 1e-6

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: 1000
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 10 # Interval of logging.
  enable_progress_bar: True
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before starting training; setting it to 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  benchmark: false # needs to be false for models with variable-length speech input as it slows down training

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: "val_wer"
    mode: "min"
    save_top_k: 5
    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

  # you need to set these two to True to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null