Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ASR] Adding ssl config for fast-conformer #6672

Merged
merged 3 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions examples/asr/conf/ssl/fast-conformer/fast-conformer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# This config contains the default values for self-supervised pre-training of a Conformer ASR model, large size (~120M).

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of Conformer-CTC, other parameters are the same as in this config file.
# One extra layer (compared to original paper) is added to the medium and large variants to compensate for replacing the LSTM decoder with a linear one.
#
# +-------------+---------+---------+----------+------------+-----+
# | Model | d_model | n_heads | n_layers | time_masks | lr |
# +=============+=========+========+===========+============+=====+
# | Large (121M)| 512 | 8 | 17 | 10 | 2.0 |
# +---------------------------------------------------------------+
#
# If you do not want to train with AMP, you may use weight decay of 0.0 or reduce the number of time maskings to 2
# with time_width=100. It may help when you want to train for fewer epochs and need faster convergence.
# With weight_decay=0.0, learning rate may need to get reduced to 2.0.

name: "FastConformer-SSL"

model:
sample_rate: 16000

train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: false
use_start_end_token: true
trim_silence: false
max_duration: 16.7
min_duration: 8.0
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null

validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false
min_duration: 8.0


preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
log: true
frame_splicing: 1
dither: 0.00001
pad_to: 16
pad_value: 0.0

spec_augment:
_target_: nemo.collections.asr.modules.MaskedPatchAugmentation
freq_masks: 3
freq_width: 20
patch_size: 48
mask_patches: 0.5

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 17
d_model: 512

# Sub-sampling params
subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
subsampling_factor: 8 # must be power of 2 for striding and vggnet
subsampling_conv_channels: 256 # -1 sets it to d_model
causal_downsampling: false

# Reduction parameters: Can be used to add another subsampling layer at a given position.
# Having a 2x reduction will speedup the training and inference speech while keeping similar WER.
# Adding it at the end will give the best WER while adding it at the beginning will give the best speedup.
reduction: null # pooling, striding, or null
reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder
reduction_factor: 1

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 9
conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
# conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
# null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1 # The dropout used before the encoder
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 1

decoder_out: 256

loss_list:
contrastive:
is_active: true # indicates whether to use this loss
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.encoder.d_model}
feat_hidden: ${model.decoder_out}
# features in hidden layer of decoder
feat_out: ${model.decoder_out}
stride_layers: 0
# if loss.combine_time_steps is less than the encoder stride, then a corresponding amount of stride_layers needs to
# be added to the decoder (here stride and combine_time_steps are both 4)
non_stride_layers: 0
loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: ${model.preprocessor.features}
proj_dim: ${model.decoder_out}
combine_time_steps: 8 # how many spectrogram time steps are used for one target/representation for contrastive task
quantized_targets: false # should quantizer or linear layer be used
# (quantizer is required to extract pseudo-labels for other losses)
codebook_size: 300 # number of vectors in the quantization codebook per group
num_groups: 2 # number of groups in the quantizer codebook
num_negatives: 100 # number of sampled negatives for each target
sample_from_same_utterance_only: true # should negatives be sampled only from the same utterance
sample_from_non_masked: false # should negatives be sampled from non-masked steps

mlm:
is_active: false # indicates whether to use this loss
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: ${model.encoder.d_model}
num_classes: 90000
# set this to be equal to codebook_size^groups in the contrastive loss
loss:
_target_: nemo.collections.asr.losses.MLMLoss
combine_time_steps: 4
targets_from_loss: "contrastive"
# since this loss requires targets, we can either get them from a manifest or from a quantized contrastive loss
loss_alpha: 1000.
# multiplier applied to this loss relative to others
transpose_encoded: false
# transposing input may be necessary depending on which layer is used as input to decoder
start_step: 0
# determines what global step this loss starts being used at;
# this can be set to a higher number if your training is long enough,
# which may increase early training stability
output_from_layer: null
# if we wanted to use outputs from non-final encoder layer as input to this decoder,
# the layer name should be specified here


optim:
name: adamw
lr: 5.0
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: NoamAnnealing
d_model: ${model.encoder.d_model}
# scheduler config override
warmup_steps: 25000
warmup_ratio: null
min_lr: 1e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 1000
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 1.0
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_loss"
mode: "min"
save_top_k: 3

# you need to set these two to True to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/ssl_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# need to be separate for moduledict

for decoder_loss_name, decoder_loss_cfg in self._cfg.loss_list.items():
if not decoder_loss_cfg.get("is_active", True): # active by default
continue

new_decoder_loss = {
'decoder': SpeechEncDecSelfSupervisedModel.from_config_dict(decoder_loss_cfg.decoder),
'loss': SpeechEncDecSelfSupervisedModel.from_config_dict(decoder_loss_cfg.loss),
Expand Down
Loading