Attention encoder-decoder models for multiple speech-to-text tasks #8242

Merged: 37 commits, Feb 3, 2024

Commits (37)
c914fcd
Rebasing canary changes at current main
pzelasko Jan 17, 2024
786adbe
Move the changes from asr transformer to nlp transformer as originall…
pzelasko Jan 18, 2024
702d22f
Merge branch 'main' into canary-2
pzelasko Jan 19, 2024
41bcf87
Merge branch 'main' into canary-2
pzelasko Jan 19, 2024
d53b3c8
update eval to strip spaces before punctuations
stevehuang52 Jan 22, 2024
4c532d7
update pc strip
stevehuang52 Jan 22, 2024
d5525bf
Merge branch 'main' into canary-2
pzelasko Jan 25, 2024
e0a214d
[canary] Refactor: `PromptedAudioToTextLhotseDataset` and `EncDecMult…
pzelasko Jan 26, 2024
02551f3
fix transcribe config
stevehuang52 Jan 27, 2024
27999c9
Refactor Canary to follow schema of remaining ASR models (#8260)
titu1994 Jan 27, 2024
6b3e18d
Merge branch 'canary-2' of https://github.com/NVIDIA/NeMo into canary-2
stevehuang52 Jan 27, 2024
b0d6f19
fix transcribe, update asr evaluator
stevehuang52 Jan 29, 2024
955eb14
Extend the docs for the canary prompt_fn
pzelasko Jan 29, 2024
dd2c97f
Incorporate changes from Nithin's code review
pzelasko Jan 29, 2024
e63db16
training bug fix and adding launch script for speech_multitask (#8270)
krishnacpuvvada Jan 29, 2024
5ccd1a1
Fix: drop_last must be true in validation/test otherwise the training…
pzelasko Jan 30, 2024
ddfc788
revert to current transcribe API
stevehuang52 Jan 30, 2024
1fbbfc9
revert changes to NLP, update docs
stevehuang52 Jan 30, 2024
5e7cd02
update eval utils
stevehuang52 Jan 30, 2024
2fd5b9d
update docs
stevehuang52 Jan 30, 2024
993550d
Remove DALI; rename compute_audio_loss to compute_loss
pzelasko Jan 30, 2024
f732b0b
set default use_model_transcribe=False
stevehuang52 Jan 31, 2024
86a3f18
change os.path.dirname to pathlib
stevehuang52 Jan 31, 2024
a595213
Merge branch 'main' into canary-2
stevehuang52 Jan 31, 2024
9b7aa0f
[canary] Test for CanaryTokenizer + refactoring (#8285)
pzelasko Jan 31, 2024
ee16e5c
Update config for AED models (#8294)
titu1994 Jan 31, 2024
05157f7
Merge branch 'main' into canary-2
titu1994 Jan 31, 2024
780d107
set default calculate_wer=False in transcribe_speech.py
stevehuang52 Feb 1, 2024
1047695
Attention encoder-decoder models for multiple speech-to-text tasks
pzelasko Jan 17, 2024
5fe0daa
Apply suggestions from code review, part 1
pzelasko Feb 2, 2024
da39af6
Apply suggestions from code review, part 2
pzelasko Feb 2, 2024
2579687
Document compute_loss
pzelasko Feb 2, 2024
aacdd1e
Merge branch 'canary-2' of https://github.com/NVIDIA/NeMo into canary…
stevehuang52 Feb 2, 2024
6a2d6e2
update transcribe_speech.py
stevehuang52 Feb 2, 2024
9bcaa08
add docstring
stevehuang52 Feb 2, 2024
e1cbea3
Attention encoder-decoder models for multiple speech-to-text tasks
pzelasko Jan 17, 2024
e429954
redo changes on transcribe, Merge branch 'canary-2' of https://github…
stevehuang52 Feb 2, 2024
280 changes: 280 additions & 0 deletions examples/asr/conf/speech_multitask/fast-conformer_aed.yaml
@@ -0,0 +1,280 @@
# This config contains the default values for training an autoregressive multi-task FastConformer-Transformer (AED) model with sub-word encoding.

# Architecture and training config:
# The default learning parameters in this config are set for an effective batch size of 2K. To train with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use a higher accumulate_grad_batches.
# Here are the recommended configs for different variants of FastConformer-Transformer; other parameters are the same as in this config file.
# One extra (linear projection) layer is added between the FastConformer encoder and the Transformer decoder if they have different hidden sizes.
# It is recommended to initialize the FastConformer encoder from an ASR pre-trained checkpoint for better accuracy and faster convergence.

name: "FastConformer-Transformer-MultiTask"

# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy
init_from_nemo_model:
Collaborator:
Is this necessary? There are so many more modules for Multi Task models - please add the remainder or remove this section.

Collaborator (Author):
We probably shouldn't mandate that. @krishnacpuvvada WDYT?

Collaborator:
We shouldn't mandate it, but we also want to highlight that large models greatly benefit from (or often require) initializing at least the encoder. Let's keep the argument but change ??? to None and leave a comment.

Collaborator (Author):
Done.

  model0:
    path: ???
    include: ["preprocessor", "encoder"]
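The `include: ["preprocessor", "encoder"]` filter restores only a subset of the pretrained checkpoint's weights. As a rough illustration of that mechanism (not the actual NeMo code path behind `init_from_nemo_model`), the sketch below copies just the matching preprocessor and encoder tensors into a multi-task model; the pretrained model name and the `.nemo` path are placeholders, and it assumes the pretrained encoder matches this config's architecture and hidden sizes.

```python
# Illustrative sketch only: approximates what init_from_nemo_model with
# include=["preprocessor", "encoder"] does; not the exact NeMo code path.
import nemo.collections.asr as nemo_asr

# Placeholder model name and path, for illustration only.
pretrained = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_transducer_large")
aed_model = nemo_asr.models.EncDecMultiTaskModel.restore_from("canary_aed.nemo")

# Keep only parameters whose names start with the included module prefixes.
include = ("preprocessor", "encoder")
filtered = {
    name: tensor
    for name, tensor in pretrained.state_dict().items()
    if name.startswith(include)
}

# strict=False lets the decoder and head keep their randomly initialized weights.
missing, unexpected = aed_model.load_state_dict(filtered, strict=False)
print(f"Initialized {len(filtered)} tensors; {len(missing)} left at random init.")
```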

model:
  _target_: nemo.collections.asr.models.EncDecMultiTaskModel
  sample_rate: 16000
  label_smoothing: 0.0
  context_len_for_AR_decoding: 5
  log_prediction: true # enables logging sample predictions in the output during training

  # Important! Set the prompt format to the class you need
  prompt_format: ??? # Options supported: ["canary"]

  model_defaults:
    asr_enc_hidden: 1024
    lm_enc_hidden: 512
    dec_hidden: 1024

  train_ds:
    use_lhotse: true
    tarred_audio_filepaths: ???
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    shuffle: true
    num_workers: 8
    # To understand the settings below, please refer to Lhotse Dataloading documentation:
    # https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#lhotse-dataloading
    # You can also check the following configuration dataclass:
    # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/lhotse/dataloader.py#L36
    batch_size: None
    batch_duration: 360
    quadratic_duration: 20
    use_bucketing: True
    num_buckets: 20
    bucket_buffer_size: 20000
    shuffle_buffer_size: 10000
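Because `batch_size` is not set, batches here are assembled by total audio duration rather than example count. The sketch below gives an intuition for how `batch_duration` and `quadratic_duration` might interact; the quadratic penalty formula (duration + duration^2 / quadratic_duration) is an assumption about Lhotse's duration-based sampling, so treat it as illustrative rather than exact.

```python
# Sketch of duration-based batch packing with a quadratic length penalty.
# The penalty formula is an assumption about Lhotse's behavior, for intuition only.

def effective_duration(duration_sec: float, quadratic_duration: float = 20.0) -> float:
    """Longer cuts are charged extra so very long utterances don't blow up attention memory."""
    return duration_sec + duration_sec**2 / quadratic_duration


def pack_batch(durations: list[float], batch_duration: float = 360.0) -> list[float]:
    """Greedily add cuts until the effective-duration budget is exhausted."""
    batch, budget = [], 0.0
    for d in durations:
        cost = effective_duration(d)
        if batch and budget + cost > batch_duration:
            break
        batch.append(d)
        budget += cost
    return batch


# A 30 s cut costs 30 + 900/20 = 75 "effective" seconds, vs. 5 + 25/20 = 6.25 for a 5 s cut,
# so batches of long utterances end up with far fewer examples.
print(len(pack_batch([30.0] * 10)), len(pack_batch([5.0] * 100)))  # -> 4 57
```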

  validation_ds:
    use_lhotse: true
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true
    use_bucketing: false

  test_ds:
    use_lhotse: true
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true
    use_bucketing: false

  # recommend small vocab size of 128 or 256 when using 4x sub-sampling
  # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
  tokenizer:
    dir: null # Null for aggregate tokenizers
    type: agg # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) or `agg` for aggregate tokenizers
    langs:
      spl_tokens: # special tokens model
        dir: ???
        type: bpe
      en: # English tokenizer (example, replace with whichever language you would like)
        dir: ???
        type: bpe

    custom_tokenizer:
      _target_: nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer # Can be replaced with other tokenizer for different prompt formats
      tokenizers: null # Filled at runtime by all the tokenizers inside the aggregate tokenizer
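The `agg` tokenizer type combines a special-tokens tokenizer with per-language tokenizers under named keys, and each sub-tokenizer's IDs are offset into one shared vocabulary. The toy classes below only illustrate that routing idea; they are not the NeMo `CanaryTokenizer` or aggregate tokenizer API, and all names in them are made up for illustration.

```python
# Toy illustration of aggregate-tokenizer routing; NOT the NeMo API.

class ToyCharTokenizer:
    def __init__(self, alphabet: str):
        self.alphabet = alphabet
        self.vocab_size = len(alphabet)

    def text_to_ids(self, text: str) -> list[int]:
        return [self.alphabet.index(ch) for ch in text if ch in self.alphabet]


class ToyAggregateTokenizer:
    """Routes text to a per-language sub-tokenizer and offsets its IDs into a shared vocab space."""

    def __init__(self, tokenizers: dict):
        self.tokenizers = tokenizers
        self.offsets, offset = {}, 0
        for lang, tok in tokenizers.items():
            self.offsets[lang] = offset
            offset += tok.vocab_size

    def text_to_ids(self, text: str, lang: str) -> list[int]:
        return [self.offsets[lang] + i for i in self.tokenizers[lang].text_to_ids(text)]


agg = ToyAggregateTokenizer({
    "spl_tokens": ToyCharTokenizer("<|>"),                   # stands in for the special-tokens model
    "en": ToyCharTokenizer("abcdefghijklmnopqrstuvwxyz '"),  # stands in for the English tokenizer
})
print(agg.text_to_ids("hello world", lang="en"))  # all IDs land past the spl_tokens range
```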

  # Audio Preprocessor
  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: ${model.sample_rate}
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    log: true
    frame_splicing: 1
    dither: 0.00001
    pad_to: 0
    pad_value: 0.0

  # SpecAugment is applied either in the model or in the data layer
  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    # you may use lower time_masks for smaller models to have a faster convergence
    time_masks: 10 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  # ASR Encoder
  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1 # you may set it if you need a different output size than the default d_model
    n_layers: 24
    d_model: ${model.model_defaults.asr_enc_hidden}

    # Sub-sampling params
    subsampling: dw_striding # dw_striding, striding, or vggnet; vggnet may give better results but needs more memory
    subsampling_factor: 8 # must be a power of 2
    subsampling_conv_channels: 256 # -1 sets it to d_model
    causal_downsampling: false
    reduction: null
    reduction_position: null
    reduction_factor: 1

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    xscaling: false # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 9
    conv_norm_type: batch_norm
    conv_context_size: null

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer Modules
    dropout_pre_encoder: 0.1
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules
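With a 0.01 s window stride and 8x depthwise-striding subsampling, each encoder output frame covers roughly 80 ms of audio, which keeps the encoder output sequence short for the decoder's cross-attention. A quick back-of-the-envelope check:

```python
# Rough arithmetic for the encoder frame rate implied by this config.
window_stride_sec = 0.01   # model.preprocessor.window_stride
subsampling_factor = 8     # model.encoder.subsampling_factor

frame_sec = window_stride_sec * subsampling_factor
frames_per_30s = int(30.0 / frame_sec)
print(f"{frame_sec * 1000:.0f} ms per encoder frame, ~{frames_per_30s} frames for 30 s of audio")
# -> 80 ms per encoder frame, ~375 frames for 30 s of audio
```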

  # Optional Transformer Encoder sandwiched between the ASR Encoder and the Transformer Decoder.
  # Only used if num_layers > 0
  transf_encoder:
    _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
    num_layers: 0
    hidden_size: ${model.model_defaults.lm_enc_hidden}
    inner_size: ${multiply:${model.model_defaults.lm_enc_hidden}, 4}
    num_attention_heads: 8
    ffn_dropout: 0.1
    attn_score_dropout: 0.1
    attn_layer_dropout: 0.1
    mask_future: False
    pre_ln: True
    pre_ln_final_layer_norm: True

  transf_decoder:
    _target_: nemo.collections.asr.modules.transformer.get_nemo_transformer
    model_name: null
    pretrained: false
    encoder: null
    pre_ln_final_layer_norm: true

    config_dict:
      max_sequence_length: 512
      num_token_types: 0
      embedding_dropout: 0.1
      learn_positional_encodings: false
      hidden_size: ${model.model_defaults.dec_hidden}
      inner_size: ${multiply:${model.model_defaults.dec_hidden}, 4}
      num_layers: 24
      num_attention_heads: 8
      ffn_dropout: 0.1
      attn_score_dropout: 0.1
      attn_layer_dropout: 0.1
      hidden_act: relu
      pre_ln: true
      vocab_size: None # Will be set by the model at runtime

  # Label Prediction Head (Token Classifier)
  head:
    _target_: nemo.collections.asr.parts.submodules.token_classifier.TokenClassifier
    num_layers: 1
    activation: relu
    log_softmax: true
    hidden_size: ${model.transf_decoder.config_dict.hidden_size}
    num_classes: None # Will be set by the model at runtime
    dropout: 0.0
    use_transformer_init: true

  # Decoding Strategy
  decoding:
    strategy: beam
    return_best_hypothesis: true # Returns the most probable hypothesis after beam search

    beam:
      beam_size: 1
      len_pen: 0.0
      max_generation_delta: 50
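With `strategy: beam` and `beam_size: 1`, decoding is effectively greedy by default. A minimal inference sketch follows; the checkpoint and audio paths are placeholders, and the multi-task model's `transcribe()` may accept additional prompt-related arguments (task, source/target language, punctuation) beyond what is shown, so treat the exact call signature as an assumption and check the class docstring.

```python
# Minimal inference sketch; paths are placeholders and the extra transcribe() kwargs
# available for the multi-task model are an assumption.
from nemo.collections.asr.models import EncDecMultiTaskModel

model = EncDecMultiTaskModel.restore_from("path/to/canary_aed.nemo")
model.eval()

# Beam-search behavior follows the `decoding` section above.
hypotheses = model.transcribe(
    ["sample1.wav", "sample2.wav"],  # 16 kHz mono audio, matching model.sample_rate
    batch_size=4,
)
for hyp in hypotheses:
    print(hyp)
```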

  # Loss Config
  loss:
    _target_: nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss
    label_smoothing: ${model.label_smoothing}
    pad_id: null

  optim:
    name: adamw
    lr: 3e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    # weight_decay is less necessary here because we already apply strong augmentation with SpecAug;
    # you may need weight_decay for large models, stable AMP training, small datasets, or when weaker augmentation is used
    # a weight decay of 0.0 with an lr of 2.0 also works fine
    weight_decay: 1e-3

    # scheduler setup
    sched:
      name: InverseSquareRootAnnealing
      #d_model: ${model.encoder.d_model}
      # scheduler config override
      warmup_steps: 2500
      warmup_ratio: null
      min_lr: 1e-6
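`InverseSquareRootAnnealing` increases the learning rate linearly for `warmup_steps` and then decays it proportionally to the inverse square root of the step count. The sketch below reproduces that shape with this config's values; the exact normalization constants in NeMo's scheduler may differ slightly, so treat it as an approximation.

```python
# Approximate shape of inverse-square-root annealing with linear warmup.
# Constants are chosen so lr == base_lr at the end of warmup; NeMo's
# implementation may normalize slightly differently.
import math

def inv_sqrt_lr(step: int, base_lr: float = 3e-4, warmup_steps: int = 2500, min_lr: float = 1e-6) -> float:
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps                    # linear warmup
    return max(min_lr, base_lr * math.sqrt(warmup_steps / step))      # ~1/sqrt(step) decay

for s in (100, 2500, 10000, 100000):
    print(s, f"{inv_sqrt_lr(s):.2e}")
```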

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: 100000 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 16 # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 100 # interval of logging
  enable_progress_bar: True
  num_sanity_val_steps: 2 # number of validation steps to run as a sanity check before training starts; setting it to 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # provided by exp_manager
  logger: false # provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: "val_sacreBLEU"
    mode: "max"
    save_top_k: 3
    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  # you need to set these two to True to continue the training
  resume_if_exists: true
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
@@ -176,7 +176,7 @@ model:
       min_lr: 1e-6

 trainer:
-  gpus: -1 # number of GPUs, -1 would use all available GPUs
+  devices: -1 # number of GPUs, -1 would use all available GPUs
   num_nodes: 1
   max_epochs: 100
   max_steps: -1 # computed at runtime if not set
74 changes: 74 additions & 0 deletions examples/asr/speech_multitask/speech_to_text_aed.py
@@ -0,0 +1,74 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Training the model
```sh
python speech_to_text_aed.py \
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
model.train_ds.audio.tarred_audio_filepaths=<path to tar files with audio> \
model.train_ds.audio_manifest_filepath=<path to audio data manifest> \
model.validation_ds.manifest_filepath=<path to validation manifest> \
model.test_ds.manifest_filepath=<path to test manifest> \
Collaborator:
Let's update this docstring to support Multi Task models.

Collaborator (Author):
Could you be more specific? I can't see what needs changing.

Collaborator:
The docstring has args that don't actually get used to train the model - we need a docstring with the Lhotse args for data and an override such as model_defaults.encoder_hidden=xyz to show how to override the main args of this model.

Collaborator (Author):
Done.

  model.tokenizer.dir=<path to directory of tokenizer (not full path to the vocab file!)> \
  model.tokenizer.model_path=<path to speech tokenizer model> \
  model.tokenizer.type=<either bpe, wpe, or yttm> \
  model.prompt_format="canary" \
  trainer.devices=-1 \
  trainer.accelerator="ddp" \
  trainer.max_steps=100000 \
  +trainer.limit_train_batches=20000 \
  trainer.val_check_interval=5000 \
  +trainer.use_distributed_sampler=false \
  model.optim.name="adamw" \
  model.optim.lr=0.001 \
  model.optim.betas=[0.9,0.999] \
  model.optim.weight_decay=0.0001 \
  model.optim.sched.warmup_steps=2000 \
  exp_manager.create_wandb_logger=True \
  exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
  exp_manager.wandb_logger_kwargs.project="<Name of project>"
```


"""

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/speech_multitask/", config_name="fast-conformer_aed")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    aed_model = EncDecMultiTaskModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    aed_model.maybe_init_from_pretrained_checkpoint(cfg)
    trainer.fit(aed_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        if aed_model.prepare_test(trainer):
            trainer.test(aed_model)


if __name__ == '__main__':
    main()
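For quick experiments outside the Hydra CLI, the same flow can be driven programmatically. The sketch below is one possible way to do that, not part of this PR: it loads the config added above, fills in the fields left as `???` with placeholder paths, and overrides a `model_defaults` entry in the spirit of the review discussion; every path and the smaller hidden size are assumptions for illustration.

```python
# Hedged sketch: drive the same training flow without the Hydra CLI.
# All paths below are placeholders; the config keys mirror fast-conformer_aed.yaml.
# Assumes importing NeMo registers the `multiply` OmegaConf resolver used by the YAML.
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.utils.exp_manager import exp_manager

cfg = OmegaConf.load("examples/asr/conf/speech_multitask/fast-conformer_aed.yaml")

# Fill in the fields the YAML leaves as ??? (placeholder paths).
cfg.model.prompt_format = "canary"
cfg.model.train_ds.tarred_audio_filepaths = "/data/audio_shards/shard_{0..127}.tar"
cfg.model.train_ds.manifest_filepath = "/data/train_manifest.json"
cfg.model.validation_ds.manifest_filepath = "/data/dev_manifest.json"
cfg.model.test_ds.manifest_filepath = "/data/test_manifest.json"
cfg.model.tokenizer.langs.spl_tokens.dir = "/tokenizers/spl_tokens"
cfg.model.tokenizer.langs.en.dir = "/tokenizers/en"

# Example of overriding a model_defaults entry (smaller encoder for a debug run).
cfg.model.model_defaults.asr_enc_hidden = 512

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = EncDecMultiTaskModel(cfg=cfg.model, trainer=trainer)
trainer.fit(model)
```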