
Commit 0095519

titu1994 and pzelasko authored and committed
Attention encoder-decoder models for multiple speech-to-text tasks (NVIDIA#8242) (NVIDIA#8324)
* Rebasing canary changes at current main Signed-off-by: Piotr Żelasko <[email protected]>
* Move the changes from asr transformer to nlp transformer as originally intended Signed-off-by: Piotr Żelasko <[email protected]>
* update eval to strip spaces before punctuations Signed-off-by: stevehuang52 <[email protected]>
* update pc strip Signed-off-by: stevehuang52 <[email protected]>
* [canary] Refactor: `PromptedAudioToTextLhotseDataset` and `EncDecMultiTaskModel` (NVIDIA#8247)
* Create a separate CanaryDataset and use it inside `transformer_bpe_models.py`. Ditches `token_sequence_format`. Signed-off-by: Piotr Żelasko <[email protected]>
* [canary] Refactor: move changes in transformer_bpe_models.py to Canar… (NVIDIA#8252)
* [canary] Refactor: move changes in transformer_bpe_models.py to CanaryModel Signed-off-by: Piotr Żelasko <[email protected]>
* Rename `CanaryModel` to `EncDecMultiTaskModel` and remove inheritance from `EncDecTransfModelBPE`; add a separate config for this model Signed-off-by: Piotr Żelasko <[email protected]>
---------
Signed-off-by: Piotr Żelasko <[email protected]>
* Rename `CanaryDataset` to `PromptedAudioToTextLhotseDataset`; add `prompt_format_fn` argument; clean-up the `_canary_prompt_format` function a bit Signed-off-by: Piotr Żelasko <[email protected]>
* Move tokenization into `prompt_format_fn`, fix usage, add docs Signed-off-by: Piotr Żelasko <[email protected]>
* Backward-compatible utterance validation Signed-off-by: Piotr Żelasko <[email protected]>
* Improve type annotations Signed-off-by: Piotr Żelasko <[email protected]>
* config and prompt_fn registration changes from review Signed-off-by: Piotr Żelasko <[email protected]>
---------
Signed-off-by: Piotr Żelasko <[email protected]>
* fix transcribe config Signed-off-by: stevehuang52 <[email protected]>
* Refactor Canary to follow schema of remaining ASR models (NVIDIA#8260)
* Initial draft of multi task beam decoding strategy Signed-off-by: smajumdar <[email protected]>
* Stabilize inference Signed-off-by: smajumdar <[email protected]>
* Update AED Multi Task model to mostly conform to Archetype-Type format. Update config Signed-off-by: smajumdar <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Add change decoding strategy Signed-off-by: smajumdar <[email protected]>
* Remove redundant imports Signed-off-by: smajumdar <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Cleanup Signed-off-by: smajumdar <[email protected]>
* Cleanup Signed-off-by: smajumdar <[email protected]>
* remove asr transformer dependency on nlp Signed-off-by: stevehuang52 <[email protected]>
* clean up Signed-off-by: stevehuang52 <[email protected]>
* copy token_classifier from nlp to asr Signed-off-by: stevehuang52 <[email protected]>
* Address comments Signed-off-by: smajumdar <[email protected]>
* Add typing to beam decoding Signed-off-by: smajumdar <[email protected]>
* Make prompt format configurable Signed-off-by: smajumdar <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* drop asr dependency on nlp Signed-off-by: stevehuang52 <[email protected]>
---------
Signed-off-by: smajumdar <[email protected]>
Signed-off-by: stevehuang52 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: stevehuang52 <[email protected]>
* fix transcribe, update asr evaluator Signed-off-by: stevehuang52 <[email protected]>
* Extend the docs for the canary prompt_fn Signed-off-by: Piotr Żelasko <[email protected]>
* Incorporate changes from Nithin's code review Signed-off-by: Piotr Żelasko <[email protected]>
* training bug fix and adding launch script for speech_multitask (NVIDIA#8270)
* bug fix and adding launch script for speech_multitask Signed-off-by: Krishna Puvvada <[email protected]>
* update launch script example in speech_to_text_aed.py Signed-off-by: Krishna Puvvada <[email protected]>
---------
Signed-off-by: Krishna Puvvada <[email protected]>
Co-authored-by: Krishna Puvvada <[email protected]>
* Fix: drop_last must be true in validation/test otherwise the training will hang Signed-off-by: Piotr Żelasko <[email protected]>
* revert to current transcribe API Signed-off-by: stevehuang52 <[email protected]>
* revert changes to NLP, update docs Signed-off-by: stevehuang52 <[email protected]>
* update eval utils Signed-off-by: stevehuang52 <[email protected]>
* update docs Signed-off-by: stevehuang52 <[email protected]>
* Remove DALI; rename compute_audio_loss to compute_loss Signed-off-by: Piotr Żelasko <[email protected]>
* set default use_model_transcribe=False Signed-off-by: stevehuang52 <[email protected]>
* change os.path.dirname to pathlib Signed-off-by: stevehuang52 <[email protected]>
* [canary] Test for CanaryTokenizer + refactoring (NVIDIA#8285)
* Test for CanaryTokenizer Signed-off-by: Piotr Żelasko <[email protected]>
* Attempt at refactor... Signed-off-by: Piotr Żelasko <[email protected]>
---------
Signed-off-by: Piotr Żelasko <[email protected]>
* Update config for AED models (NVIDIA#8294) Signed-off-by: smajumdar <[email protected]>
* set default calculate_wer=False in transcribe_speech.py Signed-off-by: stevehuang52 <[email protected]>
* Attention encoder-decoder models for multiple speech-to-text tasks Signed-off-by: Piotr Żelasko <[email protected]>
* Apply suggestions from code review, part 1 Co-authored-by: Nithin Rao <[email protected]> Signed-off-by: Piotr Żelasko <[email protected]>
* Apply suggestions from code review, part 2 Signed-off-by: Piotr Żelasko <[email protected]>
* Document compute_loss Signed-off-by: Piotr Żelasko <[email protected]>
* update transcribe_speech.py Signed-off-by: stevehuang52 <[email protected]>
* add docstring Signed-off-by: stevehuang52 <[email protected]>
* Attention encoder-decoder models for multiple speech-to-text tasks Signed-off-by: Piotr Żelasko <[email protected]>
---------
Signed-off-by: Piotr Żelasko <[email protected]>
Signed-off-by: stevehuang52 <[email protected]>
Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Krishna Puvvada <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
Co-authored-by: stevehuang52 <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Krishna Puvvada <[email protected]>
Co-authored-by: Krishna Puvvada <[email protected]>
Co-authored-by: He Huang (Steve) <[email protected]>
Co-authored-by: Nithin Rao <[email protected]>
(cherry picked from commit d10726d)
Co-authored-by: Piotr Żelasko <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
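For orientation, the pieces referenced above (the new `EncDecMultiTaskModel`, the multi-task beam decoding strategy, and the transcribe API) come together roughly as in the sketch below. This is a hedged illustration, not code from the commit: the checkpoint and audio file names are placeholders, and the exact `change_decoding_strategy()` / `transcribe()` signatures may differ between NeMo releases.

```python
# Hedged usage sketch for the new multi-task AED model; all paths are placeholders.
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecMultiTaskModel

# Restore a trained checkpoint (.nemo filename is illustrative).
model = EncDecMultiTaskModel.restore_from("multitask_aed.nemo")

# Optionally switch to a wider beam; this mirrors the `decoding` section of the
# model config added in this commit (field names assumed, verify against your version).
decoding_cfg = OmegaConf.create(
    {
        "strategy": "beam",
        "return_best_hypothesis": True,
        "beam": {"beam_size": 4, "len_pen": 0.0, "max_generation_delta": 50},
    }
)
model.change_decoding_strategy(decoding_cfg)

# Transcribe a few audio files (signature may vary across NeMo versions).
hypotheses = model.transcribe(["sample_en_1.wav", "sample_en_2.wav"], batch_size=4)
print(hypotheses)
```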
1 parent b07eee8 commit 0095519

30 files changed: +2643 -119 lines changed
@@ -0,0 +1,277 @@
# It contains the default values for training an autoregressive FastConformer-Transformer AED model with sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file.
# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes
# It is recommended to initialize FastConformer with ASR/SSL pre-trained encoder for better accuracy and faster convergence

name: "FastConformer-Transformer-MultiTask"

# Note: for larger models (1B+ params) initializing from a pretrained encoder
# may help (or even be required to) stabilize the training.
init_from_nemo_model: null

model:
  _target_: nemo.collections.asr.models.EncDecMultiTaskModel
  sample_rate: 16000
  label_smoothing: 0.0
  context_len_for_AR_decoding: 5 # Length of input prompt tokens. For example, in Canary models, we use [BOS,src_lang,task,tgt_lang,pnc] and thus the length is 5
  log_prediction: true # enables logging sample predictions in the output during training

  # Important! Set the prompt format to the class you need
  prompt_format: ??? # Options supported: ["canary"]

  model_defaults:
    asr_enc_hidden: 1024
    lm_enc_hidden: 512
    lm_dec_hidden: 1024

  train_ds:
    use_lhotse: true
    tarred_audio_filepaths: null
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    shuffle: true
    num_workers: 8
    # To understand the settings below, please refer to Lhotse Dataloading documentation:
    # https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#lhotse-dataloading
    # You can also check the following configuration dataclass:
    # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/lhotse/dataloader.py#L36
    batch_size: None
    batch_duration: 360
    quadratic_duration: 15
    use_bucketing: True
    num_buckets: 20
    bucket_buffer_size: 20000
    shuffle_buffer_size: 10000

  validation_ds:
    use_lhotse: true
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true
    use_bucketing: false

  test_ds:
    use_lhotse: true
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true
    use_bucketing: false

  # recommend small vocab size of 128 or 256 when using 4x sub-sampling
  # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
  tokenizer:
    dir: null # Null for aggregate tokenizers
    type: agg # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) or `agg` for aggregate tokenizers
    langs:
      spl_tokens: # special tokens model
        dir: ???
        type: bpe
      en: # English tokenizer (example; replace with whichever language you need, or add more entries for additional languages)
        dir: ???
        type: bpe

    custom_tokenizer:
      _target_: nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer # Can be replaced with other tokenizer for different prompt formats
      tokenizers: null # Filled at runtime by all the tokenizers inside the aggregate tokenizer

  # Audio Preprocessor
  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: ${model.sample_rate}
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    log: true
    frame_splicing: 1
    dither: 0.00001
    pad_to: 0
    pad_value: 0.0

  # SpecAugment is applied either in the model or in the data layer
  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    # you may use lower time_masks for smaller models to have a faster convergence
    time_masks: 10 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  # FastConformer Encoder
  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1 # you may set it if you need different output size other than the default d_model
    n_layers: 24
    d_model: ${model.model_defaults.asr_enc_hidden}

    # Sub-sampling params
    subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory
    subsampling_factor: 8 # must be power of 2
    subsampling_conv_channels: 256 # -1 sets it to d_model
    causal_downsampling: false
    reduction: null
    reduction_position: null
    reduction_factor: 1

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    xscaling: false # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 9
    conv_norm_type: batch_norm
    conv_context_size: null

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer Modules
    dropout_pre_encoder: 0.1
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

  # Optional Transformer Encoder sandwiched between the ASR Encoder and the Transformer Decoder.
  # Only used if num_layers > 0
  transf_encoder:
    _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
    num_layers: 0
    hidden_size: ${model.model_defaults.lm_enc_hidden}
    inner_size: ${multiply:${model.model_defaults.lm_enc_hidden}, 4}
    num_attention_heads: 8
    ffn_dropout: 0.1
    attn_score_dropout: 0.1
    attn_layer_dropout: 0.1
    mask_future: False
    pre_ln: True
    pre_ln_final_layer_norm: True

  transf_decoder:
    _target_: nemo.collections.asr.modules.transformer.get_nemo_transformer
    model_name: null
    pretrained: false
    encoder: null
    pre_ln_final_layer_norm: true

    config_dict:
      max_sequence_length: 512
      num_token_types: 0
      embedding_dropout: 0.1
      learn_positional_encodings: false
      hidden_size: ${model.model_defaults.lm_dec_hidden}
      inner_size: ${multiply:${model.model_defaults.lm_dec_hidden}, 4}
      num_layers: 24
      num_attention_heads: 8
      ffn_dropout: 0.1
      attn_score_dropout: 0.1
      attn_layer_dropout: 0.1
      hidden_act: relu
      pre_ln: true
      vocab_size: None # Will be set by the model at runtime

  # Label Prediction Head (Token Classifier)
  head:
    _target_: nemo.collections.asr.parts.submodules.token_classifier.TokenClassifier
    num_layers: 1
    activation: relu
    log_softmax: true
    hidden_size: ${model.transf_decoder.config_dict.hidden_size}
    num_classes: None # Will be set by the model at runtime
    dropout: 0.0
    use_transformer_init: true

  # Decoding Strategy
  decoding:
    strategy: beam
    return_best_hypothesis: true # Returns the most probable hypothesis after beam search

    beam:
      beam_size: 1
      len_pen: 0.0
      max_generation_delta: 50

  # Loss Config
  loss:
    _target_: nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss
    label_smoothing: ${model.label_smoothing}
    pad_id: null

  optim:
    name: adamw
    lr: 3e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    # less necessity for weight_decay as we already have large augmentations with SpecAug
    # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
    # weight decay of 0.0 with lr of 2.0 also works fine
    weight_decay: 1e-3

    # scheduler setup
    sched:
      name: InverseSquareRootAnnealing
      # scheduler config override
      warmup_steps: 2500
      warmup_ratio: null
      min_lr: 1e-6

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: 100000 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 16 # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 100 # Interval of logging.
  enable_progress_bar: True
  num_sanity_val_steps: 2 # number of validation steps to run as a sanity check before starting training; set to 0 to disable it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: "val_sacreBLEU"
    mode: "max"
    save_top_k: 3
    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  # you need to set these two to True to continue the training
  resume_if_exists: true
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
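One detail of the config above that is easy to miss: `context_len_for_AR_decoding` must match the length of the decoder prompt produced by the chosen prompt format. The config comment spells this out for the Canary format; the tiny sketch below simply restates it in code, with placeholder token spellings (the real special tokens come from the `spl_tokens` tokenizer).

```python
# Illustration of why context_len_for_AR_decoding is 5 for the Canary-style prompt.
# Token spellings are placeholders; the actual special tokens live in the spl_tokens tokenizer.
prompt = [
    "<bos>",       # start of decoding
    "<src_lang>",  # language of the input audio
    "<task>",      # e.g. transcription vs. translation
    "<tgt_lang>",  # language of the output text
    "<pnc>",       # punctuation/capitalization on or off
]
assert len(prompt) == 5  # must agree with model.context_len_for_AR_decoding
```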

examples/asr/conf/speech_translation/fast-conformer_transformer.yaml

+1 -1 lines changed
@@ -176,7 +176,7 @@ model:
     min_lr: 1e-6

 trainer:
-  gpus: -1 # number of GPUs, -1 would use all available GPUs
+  devices: -1 # number of GPUs, -1 would use all available GPUs
   num_nodes: 1
   max_epochs: 100
   max_steps: -1 # computed at runtime if not set
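The one-line change above tracks the PyTorch Lightning rename of the Trainer's `gpus` argument to `devices`. For reference, a minimal sketch of the equivalent Trainer call; the `accelerator` value is an assumption matching the AED config in this commit.

```python
import pytorch_lightning as pl

# Old API (removed in recent Lightning releases): pl.Trainer(gpus=-1, num_nodes=1, max_epochs=100)
trainer = pl.Trainer(
    devices=-1,          # -1 uses all available devices
    accelerator="auto",  # assumption; the AED config above uses `accelerator: auto`
    num_nodes=1,
    max_epochs=100,
)
```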
@@ -0,0 +1,81 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Training the model
```sh
python speech_to_text_aed.py \
    # (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
    model.train_ds.tarred_audio_filepaths=<path to tar files with audio> \
    model.train_ds.manifest_filepath=<path to audio data manifest> \
    model.train_ds.batch_duration=360 \
    model.train_ds.num_buckets=30 \
    model.train_ds.bucket_duration_bins=<optional list of precomputed float bins for bucket durations, speeds up init> \
    model.validation_ds.manifest_filepath=<path to validation manifest> \
    model.test_ds.manifest_filepath=<path to test manifest> \
    model.model_defaults.asr_enc_hidden=1024 \
    model.model_defaults.lm_enc_hidden=512 \
    model.model_defaults.lm_dec_hidden=1024 \
    model.tokenizer.langs.spl_tokens.dir=<path to the directory of prompt special tokens tokenizer> \
    model.tokenizer.langs.spl_tokens.type=bpe \
    model.tokenizer.langs.en.dir=<path to the directory of en language tokenizer (add new langs the same way)> \
    model.tokenizer.langs.en.type=bpe \
    model.prompt_format="canary" \
    trainer.devices=-1 \
    trainer.accelerator="ddp" \
    trainer.max_steps=100000 \
    +trainer.limit_train_batches=20000 \
    trainer.val_check_interval=5000 \
    +trainer.use_distributed_sampler=false \
    model.optim.name="adamw" \
    model.optim.lr=0.001 \
    model.optim.betas=[0.9,0.999] \
    model.optim.weight_decay=0.0001 \
    model.optim.sched.warmup_steps=2000 \
    exp_manager.create_wandb_logger=True \
    exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
    exp_manager.wandb_logger_kwargs.project="<Name of project>"
```


"""

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/speech_multitask/", config_name="fast-conformer_aed")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    aed_model = EncDecMultiTaskModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    aed_model.maybe_init_from_pretrained_checkpoint(cfg)
    trainer.fit(aed_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        if aed_model.prepare_test(trainer):
            trainer.test(aed_model)


if __name__ == '__main__':
    main()
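The script relies on `maybe_init_from_pretrained_checkpoint(cfg)` to honor the top-level `init_from_nemo_model` key in the config, which is how the recommendation to initialize FastConformer from an ASR/SSL pre-trained encoder is wired in. A hedged sketch of setting that key programmatically follows; file names are placeholders.

```python
from omegaconf import OmegaConf

# Load the AED config shown above (local filename is an assumption) and point the
# existing top-level `init_from_nemo_model` key at a pre-trained checkpoint (placeholder path).
cfg = OmegaConf.load("fast-conformer_aed.yaml")
cfg.init_from_nemo_model = "pretrained_fastconformer_encoder.nemo"

# The launch script then picks this up via
#   aed_model.maybe_init_from_pretrained_checkpoint(cfg)
# which loads matching weights from that checkpoint before trainer.fit().
```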
