Skip to content

Commit

Permalink
[ASR] Conformer global tokens in local attention (NVIDIA#6253)
Browse files Browse the repository at this point in the history
* global tokens

Signed-off-by: sam1373 <[email protected]>

* test, configs, docs

Signed-off-by: sam1373 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: sam1373 <[email protected]>

* style

Signed-off-by: sam1373 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* names, comments

Signed-off-by: sam1373 <[email protected]>

* move comment

Signed-off-by: sam1373 <[email protected]>

* import

Signed-off-by: sam1373 <[email protected]>

* docs

Signed-off-by: sam1373 <[email protected]>

* docs

Signed-off-by: sam1373 <[email protected]>

* disable note

Signed-off-by: sam1373 <[email protected]>

---------

Signed-off-by: sam1373 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <[email protected]>
  • Loading branch information
2 people authored and hsiehjackson committed Jun 2, 2023
1 parent e64fc69 commit 1e878e8
Show file tree
Hide file tree
Showing 8 changed files with 927 additions and 17 deletions.
7 changes: 7 additions & 0 deletions docs/source/asr/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ the transducer config at ``<NeMo_git_root>/examples/asr/conf/fastconformer/fast-

Note that both configs are subword-based (BPE).

You can also train these models with longformer-style attention (https://arxiv.org/abs/2004.05150) using the following configs: CTC config at
``<NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer-long_ctc_bpe.yaml`` and transducer config at ``<NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer-long_transducer_bpe.yaml``
This allows using the model on longer audio (up to 70 minutes with Fast Conformer). Note that the Fast Conformer checkpoints
can be used with limited context attention even if trained with full context. However, if you also want to use global tokens,
which help aggregate information from outside the limited context, then training is required.


Cache-aware Streaming Conformer
-------------------------------

Expand Down
10 changes: 6 additions & 4 deletions docs/source/asr/results.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,10 @@ There are two main ways of performing inference on long audio files in NeMo:
The first way is to use buffered inference, where the audio is divided into chunks to run on, and the output is merged afterwards.
The relevant scripts for this are contained in `this folder <https://github.com/NVIDIA/NeMo/blob/stable/examples/asr/asr_chunked_inference>`_.

The second way, specifically for models with the Conformer encoder, is to convert to local attention, which changes the costs to be linear.
This can be done even for models trained with full attention, though may result in lower WER in some cases. You can switch to local attention when running the
The second way, specifically for models with the Conformer/Fast Conformer encoder, is to use local attention, which changes the costs to be linear.
You can train Fast Conformer models with Longformer-style (https://arxiv.org/abs/2004.05150) local+global attention using one of the following configs: CTC config at
``<NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer-long_ctc_bpe.yaml`` and transducer config at ``<NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer-long_transducer_bpe.yaml``.
You can also convert any model trained with full context attention to local, though this may result in lower WER in some cases. You can switch to local attention when running the
`transcribe <https://github.com/NVIDIA/NeMo/blob/stable/examples/asr/transcribe_speech.py>`_ or `evaluation <https://github.com/NVIDIA/NeMo/blob/stable/examples/asr/transcribe_speech.py>`_
scripts in the following way:

Expand All @@ -91,7 +93,7 @@ scripts in the following way:
python speech_to_text_eval.py \
(...other parameters...) \
++model_change.conformer.self_attention_model="rel_pos_local_attn" \
++model_change.conformer.att_context_size=[64, 64]
++model_change.conformer.att_context_size=[128, 128]
Alternatively, you can change the attention model after loading a checkpoint:

Expand All @@ -100,7 +102,7 @@ Alternatively, you can change the attention model after loading a checkpoint:
asr_model = ASRModel.from_pretrained('stt_en_conformer_ctc_large')
asr_model.change_attention_model(
self_attention_model="rel_pos_local_attn",
att_context_size=[64, 64]
att_context_size=[128, 128]
)
Expand Down
208 changes: 208 additions & 0 deletions examples/asr/conf/fastconformer/fast-conformer-long_ctc_bpe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# It contains the default values for training a Fast Conformer-CTC ASR model, large size (~120M) with CTC loss and sub-word encoding.
# This version uses Longformer-style attention in order to handle longer audio

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.

# You may find more info about Fast Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer

# Differences from baseline config are in
# model.encoder.global_tokens
# model.encoder.global_tokens_spacing
# model.encoder.global_attn_separate

name: "FastConformer-Long-CTC-BPE"

model:
sample_rate: 16000
log_prediction: true # enables logging sample predictions in the output during training
ctc_reduction: 'mean_volume'
skip_nan_grad: false

train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: true
use_start_end_token: false
trim_silence: false
max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
min_duration: 0.1
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "fully_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false

test_ds:
manifest_filepath: null
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false

# recommend vocab size of 128 or 256 when training on ~1k hr datasets and 1k vocab size on 10+k hr datasets
# you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
log: true
frame_splicing: 1
dither: 0.00001
pad_to: 0
pad_value: 0.0

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
# you may use lower time_masks for smaller models to have a faster convergence
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 18
d_model: 512

# Sub-sampling params
subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
subsampling_factor: 8 # must be power of 2 for striding and vggnet
subsampling_conv_channels: 256 # -1 sets it to d_model
causal_downsampling: false

# Feed forward module's params
ff_expansion_factor: 4

self_attention_model: rel_pos_local_attn # longformer-style attention (sliding window + global tokens)
global_tokens: 1 # number of tokens that attend and are attended to by all tokens (put 0 to disable)
global_tokens_spacing: 1 # how far apart the global tokens are
global_attn_separate: false # whether global tokens should use separate q,k,v layers
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [128,128] # -1 means unlimited context
att_context_style: regular # regular or chunked_limited
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers

# Convolution module's params
conv_kernel_size: 9
conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
# conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
# null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1 # The dropout used before the encoder
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 1

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: null
num_classes: -1
vocabulary: []

# config for InterCTC loss: https://arxiv.org/abs/2102.03216
# specify loss weights and which layers to use for InterCTC
# e.g., to reproduce the paper results, set loss_weights: [0.3]
# and apply_at_layers: [8] (assuming 18 layers). Note that final
# layer loss coefficient is automatically adjusted (to 0.7 in above example)
interctc:
loss_weights: []
apply_at_layers: []

optim:
name: adamw
lr: 1e-3
# optimizer arguments
betas: [0.9, 0.98]
# less necessity for weight_decay as we already have large augmentations with SpecAug
# you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
# weight decay of 0.0 with lr of 2.0 also works fine
weight_decay: 1e-3

# scheduler setup
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 15000
warmup_ratio: null
min_lr: 1e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 1000
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_wer"
mode: "min"
save_top_k: 5
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

# you need to set these two to True to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
Loading

0 comments on commit 1e878e8

Please sign in to comment.