Modular SpeechLLM implementation for Sept. 2023 submission (SALM) (#7634)

* add initial impl of ModularizedSpeechGPTModel and integration test

* fix typo in the test name (#1)

approve the nit change

* clean up an initial version of the example config; make sure it works by test (#2)

approve as no need to review

* add the test for training_step and fix the code correspondingly (test passed now) (#3)

* add test for validation_step (#4)

* mv audio and text emb concat to prepare_llm_input so as to write test to guard the llm input

* Merge heh and zhehuai's initial version of frozen am+llm (#5)

* Merge heh and zhehuai's initial version of frozen am+llm

The previous differences are summarized here:
https://docs.google.com/document/d/1zNI4hC6vJtUfcHbrUSPaMuYWRBQdN_36H0P2NiBiuPY/edit

This PR includes
1. Finish merging the model, dataset, and config code
2. Previous tests are still enabled and passed (prepare_llm_input, training_step,
    validation_step)
3. the example training script with LS960 has been run to make sure the training
pipeline works

The major remaining works are listed here
https://docs.google.com/document/d/1o0AM7v4gcTQkPZjE0Vl9TTX4vYnGTrbXEFGWh0UhGlk/edit#bookmark=id.pzvdadt5oxyw

---------

Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>

* fix a minor init bug that broke the test (#6)

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* Clean up implementation for SALM paper and sync to NEMO v1.20.0 (#18)

* wip

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* fix data

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* fix consumed_samples

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* fix the training restart problem by storing adapter+perception model and
init them from the ckpt

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* refix state dict

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* support wer and inf

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* nan guard

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* reimpl inf and bug fix

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* multi loader

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* unfreeze lm

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* flag for load am

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* tokenizer

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* overwrite vocab size

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* support bpe dropout

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* add tarred datasets

Signed-off-by: stevehuang52 <heh@nvidia.com>

* fix sample_alpha

Signed-off-by: stevehuang52 <heh@nvidia.com>

* fix bpe dropout bugs in the mismatched context in tokenization

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* add bleu metric

Signed-off-by: stevehuang52 <heh@nvidia.com>

* update metrics

Signed-off-by: stevehuang52 <heh@nvidia.com>

* support inference and fix a bug in wer calculation

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* fix bucketing dataset

Signed-off-by: stevehuang52 <heh@nvidia.com>

* fix bleu implementation

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* support question set file per dataset/data loader in preparation for
multitask understanding; also fix bleu implementation

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* support simple random context for word boosting

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* use sacrebleu.corpus_bleu to be consistent with the rest

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* make audio_file optional in the data loader

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* add a tool to materialize mt and text data

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* compatible with tar dataset

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* temp fix for metric and speed up materialization

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* make num of context configurable

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* val_check_interval fix; make manifest dumping consistent with speech models

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* random_context_positive_ratio configurable to control precision

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* bug fix: freeze_llm flag is not passed to the model cfg

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* overwrite tensor_model_parallel_size

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* support both stt and ssl models for loading audio encoder

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* fix the inference config so as to use sampling; allow inference config update in training

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* refactor and clean up code for preprocessing collections, dataset interface, and model inference; rename some classes to be consistent with the SALM paper.
also make sure the tests pass

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* Undo changes in megatron_gpt_peft_models.py and move them to speechllm_models.py; make sure the correctness by test_speechllm_models.py::TestModularizedAudioGPTModel::test_predict_step

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* update default inference config and test golden value accordingly

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* integration test and minor fix

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* nit bug fix on manifest_filepath introduced by code cleanup

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* update workspace/ files; consider moving to examples later

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* further remove unnecessary stuff in the inference implementation

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* revert the update in default end_string to be compatible with legacy models

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

---------

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>
Signed-off-by: stevehuang52 <heh@nvidia.com>
Co-authored-by: stevehuang52 <heh@nvidia.com>
Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>

* rename 'ModularizedAudioGPTModel' to 'ModularAudioGPTLoRAModel'; move speechllm stuff under nemo/collections/multimodal/speechllm

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* update copyright; remove workspace/scripts and workspace/tools folders since the main branch has LLaMA support

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

---------

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>
Signed-off-by: stevehuang52 <heh@nvidia.com>
Co-authored-by: Zhehuai Chen <chenzhehuai.sjtu@aispeech.com>
Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
Co-authored-by: stevehuang52 <heh@nvidia.com>
4 people authored Oct 9, 2023
1 parent 2baef81 commit 215cb9d
Showing 27 changed files with 4,682 additions and 7 deletions.
320 changes: 320 additions & 0 deletions examples/multimodel/conf/speechllm/modularized_speech_gpt_config.yaml
@@ -0,0 +1,320 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: megatron_audio_gpt_peft_tuning

trainer:
  devices: 1
  accelerator: gpu
  num_nodes: 1
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  replace_sampler_ddp: False
  max_epochs: 9999
  max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10 # frequency with which training steps are logged
  val_check_interval: 1.0 # An int n runs validation every n training steps; a float in 0.0 - 1.0 runs it every fraction of an epoch, e.g. 0.25 runs validation every quarter epoch
  gradient_clip_val: 1.0
  accumulate_grad_batches: 1

exp_manager:
  # explicit_log_dir: null
  exp_dir: null
  name: ${name}
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: validation_${model.data.validation_ds.metric.name}
    save_top_k: 1
    mode: min
    save_nemo_on_train_end: True
    filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    always_save_nemo: False
    save_best_model: True
  create_early_stopping_callback: False
  early_stopping_callback_params:
    monitor: "val_loss"
    mode: "min"
    min_delta: 0.001
    patience: 10
    verbose: True
    strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.


model:
  seed: 1234
  tensor_model_parallel_size: 1 # intra-layer model parallelism
  pipeline_model_parallel_size: 1 # inter-layer model parallelism

  pretrained_audio_model: stt_en_fastconformer_transducer_large
  freeze_llm: True
  freeze_audio_encoder: False
  freeze_modality_adapter: False
  load_audio_encoder: True

  global_batch_size: 128
  micro_batch_size: 4
  restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training.
  sync_batch_comm: False
  megatron_amp_O2: False

  ## Sequence Parallelism
  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
  sequence_parallel: False

  ## Activation Checkpoint
  activations_checkpoint_granularity: null # 'selective' or 'full'
  activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
  # of each chunk at the specified granularity
  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
  activations_checkpoint_num_layers: null # not used with 'selective'
  activations_checkpoint_layers_per_pipeline: null
  answer_only_loss: True
  gradient_as_bucket_view: False

  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0

  # use_am_tokenizer: True
  # override_vocab_size: 1024

  peft:
    peft_scheme: "adapter" # can be either adapter, ia3, or ptuning
    restore_from_path: null

    # Used for adapter peft training
    adapter_tuning:
      type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter'
      adapter_dim: 32
      adapter_dropout: 0.0
      norm_position: 'pre' # This can be set to 'pre' or 'post', 'pre' is normally what is used.
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      norm_type: 'mixedfusedlayernorm' # IGNORED if linear_adapter is used, options are ['layernorm', 'mixedfusedlayernorm']

    lora_tuning:
      adapter_dim: 32
      adapter_dropout: 0.0
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal

    # Used for p-tuning peft training
    p_tuning:
      virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence
      bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck
      embedding_dim: 1024 # the size of the prompt encoder embeddings
      init_std: 0.023

  perception:
    modality_adapter:
      _target_: nemo.collections.asr.modules.ConformerEncoder
      feat_in: 1024
      feat_out: -1 # you may set it if you need different output size other than the default d_model
      n_layers: 2
      d_model: 512

      # Sub-sampling parameters
      subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
      subsampling_factor: 8 # must be power of 2 for striding and vggnet
      subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
      causal_downsampling: false

      # Reduction parameters: Can be used to add another subsampling layer at a given position.
      # Having a 2x reduction will speed up training and inference while keeping a similar WER.
      # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup.
      reduction: null # pooling, striding, or null
      reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder
      reduction_factor: 1

      # Feed forward module's params
      ff_expansion_factor: 4

      # Multi-headed Attention Module's params
      self_attention_model: rel_pos # rel_pos or abs_pos
      n_heads: 8 # may need to be lower for smaller d_models
      # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
      att_context_size: [-1, -1] # -1 means unlimited context
      att_context_style: regular # regular or chunked_limited
      xscaling: true # scales up the input embeddings by sqrt(d_model)
      untie_biases: true # unties the biases of the TransformerXL layers
      pos_emb_max_len: 5000

      # Convolution module's params
      conv_kernel_size: 9
      conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
      # conv_context_size can be "causal" or a list of two integers such that conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
      # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
      conv_context_size: null

      ### regularization
      dropout: 0.1 # The dropout used in most of the Conformer Modules
      dropout_pre_encoder: 0.1 # The dropout used before the encoder
      dropout_emb: 0.0 # The dropout used for embeddings
      dropout_att: 0.1 # The dropout for multi-headed attention modules

      # set to non-zero to enable stochastic depth
      stochastic_depth_drop_prob: 0.0
      stochastic_depth_mode: linear # linear or uniform
      stochastic_depth_start_layer: 1

    spec_augment:
      _target_: nemo.collections.asr.modules.SpectrogramAugmentation
      freq_masks: 2 # set to zero to disable it
      time_masks: 10 # set to zero to disable it
      freq_width: 27
      time_width: 0.05

    # the following are read from the pretrained AM:
    # output_dim: null
    # encoder: null
    # preprocessor: null

  data:
    end_string: "~"
    train_ds:
      # Example of how to specify paths to multiple datasets
      # manifest_filepath:
      #   - /path/to/squad.jsonl
      #   - /path/to/mnli.jsonl
      #   - /path/to/boolq.jsonl
      # Example of how each dataset is formatted
      # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'question': 'transcribe this audio', 'answer': 'I have a dream...'}
      # the 'answer' field can also be 'text', and a default 'question' field is added if missing in manifests, so as to work with ASR manifests
      manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: True
      num_workers: 0
      pin_memory: True
      max_seq_length: 2048
      min_seq_length: 1
      drop_last: True
      # Notably, the data weights are controlled by either bucketing_weights
      # or concat_sampling_probabilities depending on the dataset type (tar and
      # non-tar).
      # See audio_text_qa_dataset.py for details.
      concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random'
      context_key: 'input'
      label_key: 'output'
      # add_eos: True
      add_eos: False
      end_string: ${model.data.end_string}
      add_sep: False
      add_bos: False
      separate_prompt_and_response_with_newline: False
      truncation_field: "context" # Options: ['context', 'answer']
      index_mapping_dir: null # Path to a directory to write index mapping files.
      prompt_template: "Q: {input}\nA: {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
      # ASR configs
      sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate}
      max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset
      min_duration: 0.1
      # tarred datasets
      is_tarred: false
      tarred_audio_filepaths: null
      shuffle_n: 2048
      # bucketing params
      bucketing_strategy: "fully_randomized"
      bucketing_batch_size: null
      # sample_alpha: 0.1

    validation_ds:
      manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: False
      num_workers: 0
      pin_memory: True
      max_seq_length: 2048
      min_seq_length: 1
      drop_last: False
      context_key: ${model.data.train_ds.context_key}
      label_key: ${model.data.train_ds.label_key}
      add_eos: ${model.data.train_ds.add_eos}
      end_string: ${model.data.end_string}
      add_sep: ${model.data.train_ds.add_sep}
      add_bos: ${model.data.train_ds.add_bos}
      separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline}
      write_predictions_to_file: False
      output_file_path_prefix: null # Prefix of the file to write predictions to.
      truncation_field: "context" # Options: ['context', 'answer']
      index_mapping_dir: null # Path to a directory to write index mapping files.
      prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
      tokens_to_generate: 128
      # ASR configs
      sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate}

      log_every_n_steps: 1
      metric:
        name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
        average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
        num_classes: null

    # test_ds:
    #   manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
    #   names: null # Names of the corresponding datasets used to log metrics.
    #   global_batch_size: ${model.global_batch_size}
    #   micro_batch_size: ${model.micro_batch_size}
    #   shuffle: False
    #   num_workers: 4
    #   pin_memory: True
    #   max_seq_length: 2048
    #   min_seq_length: 1
    #   drop_last: False
    #   context_key: 'input'
    #   label_key: 'output'
    #   add_eos: ${model.data.train_ds.add_eos}
    #   end_string: ${model.data.end_string}
    #   add_sep: ${model.data.train_ds.add_sep}
    #   add_bos: ${model.data.train_ds.add_bos}
    #   separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline}
    #   write_predictions_to_file: False
    #   output_file_path_prefix: null # Prefix of the file to write predictions to.
    #   truncation_field: "context" # Options: ['context', 'answer']
    #   index_mapping_dir: null # Path to a directory to write index mapping files.
    #   prompt_template: ${model.data.train_ds.prompt_template}
    #   # ASR configs
    #   sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate}

    #   metric:
    #     name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
    #     average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
    #     num_classes: null

  optim:
    name: fused_adam
    lr: 1e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 50
      min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1
      constant_steps: 0 # Constant steps should also be 0 when min_lr=0
      monitor: val_loss
      reduce_on_plateau: false
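
The comments under `train_ds` describe the manifest format and the `prompt_template` only briefly, so here is a minimal Python sketch of how one manifest line might be turned into a prompt. It is an illustration under stated assumptions, not the code added by this commit: the helper `build_example`, the `DEFAULT_QUESTION` text, and the mapping of the manifest's `question` field onto the `{input}` placeholder are all hypothetical. Only the field names, the `text` fallback, and the template string come from the config.

```python
import json

# Template string copied from `prompt_template` in the config above.
PROMPT_TEMPLATE = "Q: {input}\nA: {output}"

# Hypothetical default; the config comment only says that a default
# 'question' is added when the manifest does not provide one.
DEFAULT_QUESTION = "transcribe this audio"


def build_example(manifest_line: str, template: str = PROMPT_TEMPLATE) -> dict:
    """Turn one JSONL manifest line into context, answer, and full prompt text."""
    entry = json.loads(manifest_line)

    # Assumption: the manifest's 'question' fills the {input} placeholder.
    question = entry.get("question", DEFAULT_QUESTION)

    # Plain ASR manifests store the transcript under 'text' instead of 'answer'.
    answer = entry["answer"] if "answer" in entry else entry["text"]

    # Everything before {output} is the context. Keeping it separate from the
    # answer is what a setting like `answer_only_loss` would rely on.
    context_template, _, _ = template.partition("{output}")
    context = context_template.format(input=question)

    return {
        "context": context,
        "answer": answer,
        "full_text": template.format(input=question, output=answer),
    }


if __name__ == "__main__":
    # An ASR-style line with 'text' and no 'question' field.
    line = json.dumps(
        {"audio_filepath": "audio1.wav", "offset": 0.0, "duration": 12.3, "text": "i have a dream"}
    )
    example = build_example(line)
    print(repr(example["context"]))
    print(repr(example["full_text"]))
```

Note that the sample entry in the config comment is written with single quotes, which is Python-dict notation; the sketch assumes each manifest line is valid JSON with double quotes.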