Commit

[canary] Refactor: PromptedAudioToTextLhotseDataset and `EncDecMultiTaskModel` (#8247)

* Create a separate CanaryDataset and use it inside `transformer_bpe_models.py`. Ditches `token_sequence_format`.

Signed-off-by: Piotr Żelasko <[email protected]>

* [canary] Refactor: move changes in transformer_bpe_models.py to CanaryModel (#8252)

* [canary] Refactor: move changes in transformer_bpe_models.py to CanaryModel

Signed-off-by: Piotr Żelasko <[email protected]>

* Rename `CanaryModel` to `EncDecMultiTaskModel` and remove inheritance from `EncDecTransfModelBPE`; add a separate config for this model

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>

* Rename `CanaryDataset` to `PromptedAudioToTextLhotseDataset`; add `prompt_format_fn` argument; clean-up the `_canary_prompt_format` function a bit

Signed-off-by: Piotr Żelasko <[email protected]>

* Move tokenization into `prompt_format_fn`, fix usage, add docs

Signed-off-by: Piotr Żelasko <[email protected]>

* Backward-compatible utterance validation

Signed-off-by: Piotr Żelasko <[email protected]>

* Improve type annotations

Signed-off-by: Piotr Żelasko <[email protected]>

* config and prompt_fn registration changes from review

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
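
The commit message above describes moving prompt construction and tokenization out of the dataset class and into a `prompt_format_fn`. The new `PromptedAudioToTextLhotseDataset` itself is not part of this excerpt, so the following is only a rough sketch of what such a callable could look like, reusing the canary token layout of the removed `_canary_format` method shown further below; the function name, signature, and tokenizer attribute names are assumptions.

# Hypothetical sketch of a prompt-format callable in the spirit of this refactor.
# Not the code added by this commit; the signature and the tokenizer attributes
# (bos_id, eos_id, nospeech_id, ...) are assumptions mirroring the removed _canary_format.
def canary_prompt_format(cuts, tokenizer) -> list[list[int]]:
    """Tokenize each cut and wrap it in canary control tokens:
    bos, src_lang/no_speech, transcribe/translate, tgt_lang, pnc/nopnc, text, eos."""
    prompted = []
    for cut in cuts:
        text_ids = tokenizer(cut.supervisions[0].text, cut.supervisions[0].language)
        ids = [tokenizer.bos_id]
        if len(text_ids) == 0:
            ids.append(tokenizer.nospeech_id)  # no-speech utterance
        else:
            ids.append(tokenizer.to_language_id(cut.custom["source_lang"]))
            ids.append(tokenizer.transcribe_id if cut.custom["taskname"] == "asr" else tokenizer.translate_id)
            ids.append(tokenizer.to_language_id(cut.custom["target_lang"]))
            pnc = str(cut.custom["pnc"]).lower().strip()
            ids.append(tokenizer.pnc_id if pnc in ("yes", "true") else tokenizer.nopnc_id)
        ids.extend(text_ids)
        ids.append(tokenizer.eos_id)
        prompted.append(ids)
    return prompted

Keeping tokenization inside the prompt function means the dataset no longer needs format-specific branches, which is what the `__getitem__` simplification in `audio_to_text_lhotse.py` below reflects.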
pzelasko authored Jan 26, 2024
1 parent d5525bf commit e0a214d
Showing 6 changed files with 1,142 additions and 217 deletions.
234 changes: 234 additions & 0 deletions examples/asr/conf/speech_multitask/fast-conformer_aed.yaml
@@ -0,0 +1,234 @@
# It contains the default values for training an autoregressive FastConformer-Transformer ST model with sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file.
# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes
# It is recommended to initialize FastConformer with ASR pre-trained encoder for better accuracy and faster convergence

name: "FastConformer-Transformer-MultiTask"

# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy
init_from_nemo_model:
  model0:
    path: ???
    include: ["preprocessor", "encoder"]

model:
  _target_: nemo.collections.asr.models.EncDecMultiTaskModel
  sample_rate: 16000
  label_smoothing: 0.0
  log_prediction: true # enables logging sample predictions in the output during training

  train_ds:
    use_lhotse: true
    tarred_audio_filepaths: ???
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    shuffle: true
    num_workers: 8
    # To understand the settings below, please refer to Lhotse Dataloading documentation:
    # https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#lhotse-dataloading
    # You can also check the following configuration dataclass:
    # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/lhotse/dataloader.py#L36
    batch_size: None
    batch_duration: 360
    quadratic_duration: 20
    use_bucketing: True
    num_buckets: 20
    bucket_buffer_size: 20000
    shuffle_buffer_size: 10000

  validation_ds:
    use_lhotse: true
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true
    use_bucketing: false
    drop_last: false

  test_ds:
    use_lhotse: true
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true
    use_bucketing: false
    drop_last: false

  # recommend small vocab size of 128 or 256 when using 4x sub-sampling
  # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
  tokenizer:
    dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: ${model.sample_rate}
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    log: true
    frame_splicing: 1
    dither: 0.00001
    pad_to: 0
    pad_value: 0.0

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    # you may use lower time_masks for smaller models to have a faster convergence
    time_masks: 10 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1 # you may set it if you need different output size other than the default d_model
    n_layers: 24
    d_model: 1024

    # Sub-sampling params
    subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory
    subsampling_factor: 8 # must be power of 2
    subsampling_conv_channels: 256 # -1 sets it to d_model
    causal_downsampling: false
    reduction: null
    reduction_position: null
    reduction_factor: 1

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    xscaling: false # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 9
    conv_norm_type: batch_norm
    conv_context_size: null

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer Modules
    dropout_pre_encoder: 0.1
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

  transf_encoder:
    num_layers: 0
    hidden_size: 512
    inner_size: 2048
    num_attention_heads: 8
    ffn_dropout: 0.1
    attn_score_dropout: 0.1
    attn_layer_dropout: 0.1

  transf_decoder:
    library: nemo
    model_name: null
    pretrained: false
    max_sequence_length: 512
    num_token_types: 0
    embedding_dropout: 0.1
    learn_positional_encodings: false
    hidden_size: 1024
    inner_size: 4096
    num_layers: 24
    num_attention_heads: 8
    ffn_dropout: 0.1
    attn_score_dropout: 0.1
    attn_layer_dropout: 0.1
    hidden_act: relu
    pre_ln: true
    pre_ln_final_layer_norm: true

  head:
    num_layers: 1
    activation: relu
    log_softmax: true
    dropout: 0.0
    use_transformer_init: true

  beam_search:
    beam_size: 1
    len_pen: 0.0
    max_generation_delta: 50

  optim:
    name: adamw
    lr: 3e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    # less necessity for weight_decay as we already have large augmentations with SpecAug
    # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
    # weight decay of 0.0 with lr of 2.0 also works fine
    weight_decay: 1e-3

    # scheduler setup
    sched:
      name: InverseSquareRootAnnealing
      #d_model: ${model.encoder.d_model}
      # scheduler config override
      warmup_steps: 2500
      warmup_ratio: null
      min_lr: 1e-6

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: 100000 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP.
  log_every_n_steps: 100 # Interval of logging.
  enable_progress_bar: True
  num_sanity_val_steps: 2 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: "val_sacreBLEU"
    mode: "max"
    save_top_k: 3
    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  # you need to set these two to True to continue the training
  resume_if_exists: true
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
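
A config like the one above is consumed by a Hydra/OmegaConf-driven NeMo script; the dedicated training script for this model is not part of this excerpt. The snippet below is only a minimal sketch of the generic NeMo ModelPT pattern for such a config, with the `???` fields assumed to be filled in beforehand.

# Minimal sketch of consuming the config above via the generic NeMo ModelPT pattern;
# the actual example script for this model is not shown in this commit excerpt,
# so treat this as an approximation rather than the official entry point.
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.utils.exp_manager import exp_manager

cfg = OmegaConf.load("examples/asr/conf/speech_multitask/fast-conformer_aed.yaml")
# Fill the ??? placeholders before instantiation, e.g.:
# cfg.model.train_ds.tarred_audio_filepaths / manifest_filepath, cfg.model.tokenizer.dir,
# cfg.init_from_nemo_model.model0.path (pre-trained ASR encoder checkpoint).

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = EncDecMultiTaskModel(cfg=cfg.model, trainer=trainer)
model.maybe_init_from_pretrained_checkpoint(cfg)  # applies init_from_nemo_model
trainer.fit(model)

Note that with the Lhotse options in `train_ds`, the effective batch is governed by `batch_duration` (seconds of audio per batch, with a quadratic length penalty) rather than a fixed `batch_size`.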
80 changes: 3 additions & 77 deletions nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -15,7 +15,6 @@
 from typing import Dict, Optional, Tuple
 
 import torch.utils.data
-from lhotse.cut import MixedCut, MonoCut
 from lhotse.dataset import AudioSamples
 from lhotse.dataset.collation import collate_vectors
 
@@ -44,91 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
             'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
         }
 
-    def __init__(self, tokenizer, token_sequence_format: str = None):
+    def __init__(self, tokenizer):
         super().__init__()
         self.tokenizer = TokenizerWrapper(tokenizer)
         self.load_audio = AudioSamples(fault_tolerant=True)
-        assert token_sequence_format is None or token_sequence_format in [
-            'canary'
-        ], f"Unsupported token_sequence_format: {token_sequence_format}"
-        self.token_sequence_format = token_sequence_format
 
     def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
         audio, audio_lens, cuts = self.load_audio(cuts)
 
-        tokens = [self.tokenizer(c.supervisions[0].text, c.supervisions[0].language) for c in cuts]
-        if self.token_sequence_format == 'canary':
-            tokens = self._canary_format(tokens, cuts)
-        tokens = [torch.as_tensor(t) for t in tokens]
-
+        tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts]
         token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
-
-        if self.token_sequence_format == 'canary':
-            padding_value = self.tokenizer._tokenizer.pad_id
-        else:
-            padding_value = 0
-        tokens = collate_vectors(tokens, padding_value=padding_value)
-
+        tokens = collate_vectors(tokens, padding_value=0)
         return audio, audio_lens, tokens, token_lens
 
-    def _canary_format(self, tokens, cuts):
-        """
-        prepend and append control tokens to the token sequence as per canary format
-        Format:
-        sot, src_lang_id/no_speech, transcribe/translate, tgt_lang_id, text, eot
-        """
-        canary_tokens = []
-        for t, c in zip(tokens, cuts):
-            if isinstance(c, MixedCut):
-                c = c._first_non_padding_cut
-            assert isinstance(c, MonoCut), "Expected MonoCut."
-
-            c_t = []  # canary_tokens for this cut
-
-            # bos
-            c_t.append(self.tokenizer._tokenizer.bos_id)
-
-            # if len(t) is 0 append no-speech token
-            if len(t) == 0:
-                c_t.append(self.tokenizer._tokenizer.nospeech_id)
-            else:
-                # src_lang_id/no_speech
-                src_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['source_lang'])
-                c_t.append(src_lang_id)
-
-                # task
-                task = c.custom['taskname']
-                if task == 'asr':
-                    c_t.append(self.tokenizer._tokenizer.transcribe_id)
-                elif task == 's2t_translation':
-                    c_t.append(self.tokenizer._tokenizer.translate_id)
-                else:
-                    raise ValueError(f"Unknown task: {task}")
-
-                # tgt_lang_id
-                tgt_lang_id = self.tokenizer._tokenizer.to_language_id(c.custom['target_lang'])
-                c_t.append(tgt_lang_id)
-
-                # PnC
-                pnc = f"{c.custom['pnc']}".lower().strip()  # to account for bool or str
-                if pnc in set(['yes', 'true']):
-                    c_t.append(self.tokenizer._tokenizer.pnc_id)
-                elif pnc in set(['no', 'false']):
-                    c_t.append(self.tokenizer._tokenizer.nopnc_id)
-                else:
-                    raise ValueError(f"Unknown PnC: {pnc}")
-
-            # text
-            c_t.extend(t)
-
-            # eos
-            c_t.append(self.tokenizer._tokenizer.eos_id)
-
-            canary_tokens.append(c_t)
-
-        return canary_tokens
-
 
 class TokenizerWrapper:
     """
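With `token_sequence_format` gone, the dataset in `audio_to_text_lhotse.py` only tokenizes each cut's supervision text and pads with 0. Below is a rough usage sketch with a Lhotse bucketing sampler; the dataset class name (`LhotseSpeechToTextBpeDataset`) and the tokenizer construction are assumed, since neither is visible in the excerpt above.

# Rough usage sketch for the simplified dataset; class name and tokenizer setup
# are assumptions based on the file being edited, adjust to your environment.
import torch
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.common.tokenizers import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer("tokenizer.model")  # assumed tokenizer artifact
cuts = CutSet.from_file("train_cuts.jsonl.gz")         # assumed pre-built Lhotse manifest

dataset = LhotseSpeechToTextBpeDataset(tokenizer=tokenizer)
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=360.0,  # analogous to batch_duration in the YAML config above
    num_buckets=20,
    shuffle=True,
)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=4)

for audio, audio_lens, tokens, token_lens in loader:
    # audio: (B, T) float samples; tokens: (B, U) ids padded with 0, as returned by __getitem__
    break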
