Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

st standalone model #6969

Merged
merged 23 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ed71fbf
st standalone model
AlexGrinch Jul 3, 2023
842f2e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2023
10019f8
style fix
AlexGrinch Jul 4, 2023
24f7204
Merge branch 'speechtrans' of https://github.com/NVIDIA/NeMo into spe…
AlexGrinch Jul 4, 2023
83d81c5
Merge branch 'main' into speechtrans
AlexGrinch Jul 4, 2023
6615136
sacrebleu import fix, unused imports removed
AlexGrinch Jul 5, 2023
ced657b
import guard for nlp inside asr transformer bpe model
AlexGrinch Jul 5, 2023
e91980a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2023
14318a0
Merge branch 'main' into speechtrans
AlexGrinch Jul 5, 2023
1a13c14
codeql fixes
AlexGrinch Jul 6, 2023
b4289f5
merge head
AlexGrinch Jul 6, 2023
67a3d96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2023
98bb3ec
Merge branch 'main' into speechtrans
AlexGrinch Jul 11, 2023
0696263
comments answered
AlexGrinch Jul 12, 2023
249f312
import ordering fix
AlexGrinch Jul 12, 2023
c2c00d4
yttm for asr removed
AlexGrinch Jul 13, 2023
7eacacb
Merge branch 'main' into speechtrans
AlexGrinch Jul 13, 2023
66d428f
logging added
AlexGrinch Jul 13, 2023
bc40262
Merge branch 'main' into speechtrans
AlexGrinch Jul 14, 2023
5fed6fd
added inference and translate method
AlexGrinch Jul 14, 2023
66fdfcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2023
e097a44
Merge branch 'main' into speechtrans
AlexGrinch Jul 17, 2023
db4762b
Merge branch 'main' into speechtrans
AlexGrinch Jul 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 218 additions & 0 deletions examples/asr/conf/speech_translation/fast-conformer_transformer.yaml
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

speech_translation/fast_conformer/fastconformer_transformer.yaml

Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# It contains the default values for training an autoregressive FastConformer-Transformer ST model with sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file.
# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes
# It is recommended to initialize FastConformer with ASR pre-trained encoder for better accuracy and faster convergence

name: "FastConformer-Transformer-BPE-st"

# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy
init_from_nemo_model:
model0:
path: ???
include: ["preprocessor", "encoder"]

model:
sample_rate: 16000
label_smoothing: 0.0
log_prediction: true # enables logging sample predictions in the output during training

train_ds:
is_tarred: true
tarred_audio_filepaths: ???
manifest_filepath: ???
sample_rate: 16000
shuffle: false
trim_silence: false
batch_size: 4
num_workers: 8

validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 4
pin_memory: true
use_start_end_token: true

test_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 4
pin_memory: true
use_start_end_token: true

# recommend small vocab size of 128 or 256 when using 4x sub-sampling
# you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
tokenizer:
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
log: true
frame_splicing: 1
dither: 0.00001
pad_to: 0
pad_value: 0.0

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
# you may use lower time_masks for smaller models to have a faster convergence
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 17
d_model: 512

# Sub-sampling params
subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory
subsampling_factor: 8 # must be power of 2
subsampling_conv_channels: 256 # -1 sets it to d_model
causal_downsampling: false
reduction: null
reduction_position: null
reduction_factor: 1

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 9
conv_norm_type: batch_norm
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

transf_encoder:
num_layers: 0
hidden_size: 512
inner_size: 2048
num_attention_heads: 8
ffn_dropout: 0.1
attn_score_dropout: 0.1
attn_layer_dropout: 0.1

transf_decoder:
library: nemo
model_name: null
pretrained: false
max_sequence_length: 512
num_token_types: 0
embedding_dropout: 0.1
learn_positional_encodings: false
hidden_size: 512
inner_size: 2048
num_layers: 6
num_attention_heads: 4
ffn_dropout: 0.1
attn_score_dropout: 0.1
attn_layer_dropout: 0.1
hidden_act: relu
pre_ln: true
pre_ln_final_layer_norm: true

head:
num_layers: 1
activation: relu
log_softmax: true
dropout: 0.0
use_transformer_init: true

beam_search:
beam_size: 4
len_pen: 0.0
max_generation_delta: 50

optim:
name: adam
lr: 0.0001
# optimizer arguments
betas: [0.9, 0.98]
# less necessity for weight_decay as we already have large augmentations with SpecAug
# you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
# weight decay of 0.0 with lr of 2.0 also works fine
#weight_decay: 1e-3

# scheduler setup
sched:
name: InverseSquareRootAnnealing
#d_model: ${model.encoder.d_model}
# scheduler config override
warmup_steps: 1000
warmup_ratio: null
min_lr: 1e-6

trainer:
gpus: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 100
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 100 # Interval of logging.
enable_progress_bar: True
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_sacreBLEU"
mode: "max"
save_top_k: 3
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

# you need to set these two to True to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
70 changes: 70 additions & 0 deletions examples/asr/speech_translation/speech_to_text_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Training the model
```sh
python speech_to_text_transformer.py \
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
model.train_ds.audio.tarred_audio_filepaths=<path to tar files with audio> \
model.train_ds.audio_manifest_filepath=<path to audio data manifest> \
model.validation_ds.manifest_filepath=<path to validation manifest> \
model.test_ds.manifest_filepath=<path to test manifest> \
model.tokenizer.dir=<path to directory of tokenizer (not full path to the vocab file!)> \
model.tokenizer.model_path=<path to speech tokenizer model> \
model.tokenizer.type=<either bpe, wpe, or yttm> \
trainer.gpus=-1 \
trainer.accelerator="ddp" \
trainer.max_epochs=100 \
model.optim.name="adamw" \
model.optim.lr=0.001 \
model.optim.betas=[0.9,0.999] \
model.optim.weight_decay=0.0001 \
model.optim.sched.warmup_steps=2000
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
exp_manager.wandb_logger_kwargs.project="<Name of project>"
```


"""

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecTransfModelBPE
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/speech_translation/", config_name="fast-conformer_transformer")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
asr_model = EncDecTransfModelBPE(cfg=cfg.model, trainer=trainer)

# Initialize the weights of the model from another model, if provided via config
asr_model.maybe_init_from_pretrained_checkpoint(cfg)
trainer.fit(asr_model)

if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
if asr_model.prepare_test(trainer):
trainer.test(asr_model)


if __name__ == '__main__':
main()
Loading
Loading