# Modular SpeechLLM

This directory contains example scripts to train and evaluate modular SpeechLLM models [1].

## Requirements
You will need to install this specific branch of NeMo, or use the Dockerfile provided in the root directory of this repository to build a Docker image with all the necessary dependencies. This branch is based on the NeMo main branch as of 2/14/2024 and diverges from it in the following ways:
- It migrates to pytorch_lightning==2.2 to fix bugs with multiple validation `dataloader_iter`s and with saving `-last.ckpt` files.
- It pins megatron-core==0.4.0 to avoid potentially unstable behavior from newer versions that are not yet well supported by NeMo components.
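
If you build your own environment instead of using the Dockerfile, a quick version check can confirm the pinned dependencies are in place. This is a minimal sketch; it assumes the packages are installed under the distribution names `pytorch-lightning` and `megatron-core`:

```python
# Sanity-check that the environment matches the pins used by this branch.
from importlib.metadata import PackageNotFoundError, version

pins = {"pytorch-lightning": "2.2", "megatron-core": "0.4.0"}
for name, expected in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: NOT INSTALLED (expected {expected}*)")
        continue
    status = "OK" if installed.startswith(expected) else "MISMATCH"
    print(f"{name}: {installed} (expected {expected}*) -> {status}")
```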


## Architecture

In general, a modular SpeechLLM model consists of three main components:
- An audio encoder that processes the input audio and produces a sequence of audio embeddings.
- A modality adapter that processes the audio embeddings and produces a sequence of embeddings in the same latent space as the token embeddings of a pretrained large language model (LLM).
- A pretrained LLM that processes the embeddings from the modality adapter together with the token embeddings of the input prompt, and produces the text output. The audio embeddings and text token embeddings are concatenated along the time dimension before going into the LLM, as sketched below.
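
The following PyTorch-style pseudocode sketches how the three components fit together. The module names (`audio_encoder`, `modality_adapter`, `llm`, `embed_tokens`) are illustrative placeholders rather than the actual class names used in the configs, and the `inputs_embeds` call is HF-style shorthand; the real NeMo GPT forward signature differs.

```python
import torch
import torch.nn as nn


class ModularSpeechLLMSketch(nn.Module):
    """Minimal sketch of the modular SpeechLLM forward pass (illustrative only)."""

    def __init__(self, audio_encoder, modality_adapter, llm, embed_tokens):
        super().__init__()
        self.audio_encoder = audio_encoder        # frozen or trainable ASR encoder
        self.modality_adapter = modality_adapter  # maps audio features -> LLM embedding space
        self.llm = llm                            # pretrained (and usually frozen) LLM
        self.embed_tokens = embed_tokens          # LLM token embedding table

    def forward(self, audio_signal, audio_lengths, prompt_token_ids):
        # 1) Audio -> sequence of audio embeddings.
        audio_feats = self.audio_encoder(audio_signal, audio_lengths)  # [B, T_audio, D_audio]
        # 2) Project audio embeddings into the LLM embedding space.
        audio_embs = self.modality_adapter(audio_feats)                # [B, T_audio, D_llm]
        # 3) Embed the text prompt tokens.
        text_embs = self.embed_tokens(prompt_token_ids)                # [B, T_text, D_llm]
        # 4) Concatenate along the time dimension and feed the LLM.
        inputs_embeds = torch.cat([audio_embs, text_embs], dim=1)      # [B, T_audio + T_text, D_llm]
        return self.llm(inputs_embeds=inputs_embeds)
```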


## Usage

### Input Format

You'll need to prepare data in the NeMo manifest format, where each line is a Python dictionary with keys like the following:
```
{
    "audio_filepath": "path/to/audio.wav",
    "offset": 0.0,  # offset of the audio in seconds, this is an optional field
    "duration": 10.0,  # duration of the audio in seconds, can be set to `None` to load the whole audio
    "question": "what is the transcription of the audio?",  # this is an optional field, see below for more details
    "answer": "the transcription of the audio",  # optional for inference
}
```

The `question` field in the manifest is optional. Instead of setting it, you can put a list of questions in a file and set `++model.data.train_ds.question_file=<path to question file>` so that the dataloader randomly picks a question from the file for each audio sample. This is useful for training with multiple prompts for the same task. If neither the `question` field nor `question_file` is provided, the dataloader will use the default question `what does the audio mean?` for all audios.
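
As a convenience, here is a small sketch of how such a manifest could be generated in Python. The file name and entries below are placeholders, not files shipped with this repository:

```python
import json

# Hypothetical example entries; replace with your own audio paths and text.
entries = [
    {
        "audio_filepath": "path/to/audio_0.wav",
        "duration": 10.0,
        "question": "what is the transcription of the audio?",
        "answer": "the transcription of the audio",
    },
    {
        "audio_filepath": "path/to/audio_1.wav",
        "duration": 7.5,
        "question": "what is the transcription of the audio?",
        "answer": "another transcription",
    },
]

# NeMo manifests are JSON Lines: one JSON object per line.
with open("train_manifest.json", "w", encoding="utf-8") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")
```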


### Training

There are several configs for training a SpeechLLM:
- `conf/modular_audio_gpt_config_peft.yaml`: a config for training a SpeechLLM model with PEFT (e.g., LoRA), where you don't want to tune the whole LLM but still want to adapt it to your needs.
- `conf/modular_audio_gpt_config_sft.yaml`: a config for training a SpeechLLM model without PEFT, where you might want to tune the whole LLM or simply freeze it and use it as is.
- `conf/modular_audio_gpt_multi_enc_config_peft.yaml`: a config for training a SpeechLLM model with multiple audio encoders and PEFT, where you can add speaker embeddings to the audio embeddings. Currently only TitaNet is supported as the speaker encoder.

With any config, you can set the following flags to control which components to train or freeze:
- `model.freeze_llm`: generally set to `True`, unless you want to fine-tune the whole LLM.
- `model.freeze_audio_encoder`: generally set to `False`, unless you want to keep the audio encoder frozen.
- `model.freeze_modality_adapter`: generally set to `False`, since the modality adapter needs to be trained.

In addition to the config file, you will also need to prepare the audio encoder and the LLM as `*.nemo` files.
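
For the audio encoder, one way to obtain a `*.nemo` file is to download a pretrained NeMo ASR model and save it locally. This is a minimal sketch; the model name is just an example from NGC:

```python
import nemo.collections.asr as nemo_asr

# Download a pretrained ASR model and save it as a .nemo file to use as the audio encoder.
asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_transducer_large")
asr_model.save_to("audio-encoder.nemo")
```

The LLM side expects a Megatron-style `.nemo` checkpoint, which is passed to `model.restore_from_path` in the command below.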

To train a SpeechLLM model, you can run the following script:
```bash
MEGATRON_MODEL=/path/to/megatron-model.nemo
ASR_MODEL=/path/to/audio-encoder.nemo

TRAIN_MANIFESTS="[/data/train_1.json,/data/train_2.json]"
VAL_MANIFESTS="[/data/dev_1.json,/data/dev_2.json]"
VAL_NAMES="[dev-1,dev-2]"

# global_batch_size = micro_batch_size * num_gpus_per_node * num_nodes * gradient_accumulation_steps
# micro_batch_size  = batch size per GPU
NVTE_FLASH_ATTN=0 \
NVTE_FUSED_ATTN=0 \
NVTE_MASKED_SOFTMAX_FUSION=0 \
CUDA_VISIBLE_DEVICES="0,1" python modular_audio_gpt_train.py --config-path="./conf" --config-name "modular_audio_gpt_config_peft" \
  trainer.devices=-1 \
  model.freeze_audio_encoder=True \
  model.freeze_llm=True \
  model.global_batch_size=4 \
  model.micro_batch_size=2 \
  model.pretrained_audio_model=$ASR_MODEL \
  model.restore_from_path=$MEGATRON_MODEL \
  model.data.train_ds.manifest_filepath=$TRAIN_MANIFESTS \
  model.data.validation_ds.manifest_filepath=$VAL_MANIFESTS \
  ++model.data.validation_ds.names=$VAL_NAMES
```
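
For reference, with the settings above (`global_batch_size=4`, `micro_batch_size=2`) on a single node with 2 GPUs, the implied gradient accumulation is `global_batch_size / (micro_batch_size * num_gpus_per_node * num_nodes) = 4 / (2 * 2 * 1) = 1`, i.e., no gradient accumulation.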

You can also use tarred datasets for faster training: convert regular NeMo datasets to tarred datasets using this [script](https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py), then follow the tarred dataset settings shown in the example below.


#### Multi-task training
To use a question file, set `++model.data.train_ds.question_file=<path to question file>` on the command line, or pass multiple question files with `++model.data.train_ds.question_file=[<path to question file1>,<path to question file2>,...]`. If the number of question files equals the number of provided datasets, the dataloader assigns each question file to the corresponding dataset; otherwise, it randomly picks a question file from all provided question files for each audio sample. Using multiple question files is useful for training on multiple tasks, where each task has its own set of prompts. Meanwhile, you can control the weights of different tasks/datasets by using concatenated tarred datasets, where you assign weights to datasets as follows:
```
++model.data.train_ds.is_tarred=True \
++model.data.train_ds.is_concat=True \
++model.data.train_ds.manifest_filepath=[/path/to/data1/tarred_audio_manifest.json,/path/to/data2/tarred_audio_manifest.json] \
++model.data.train_ds.tarred_audio_filepaths=[/path/to/data1/audio__OP_0..1023_CL_.tar,/path/to/data2/audio__OP_0..1023_CL_.tar] \
++model.data.train_ds.concat_sampling_technique='random' \
++model.data.train_ds.concat_sampling_probabilities=[0.4,0.6]
```
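
As an illustration of per-task prompts, here is a small sketch that writes one question file per task. The file format is an assumption here (plain text, one question per line); double-check it against the dataloader implementation before use, and note that the file names and questions are hypothetical:

```python
# Hypothetical per-task question files (format assumed: one question per line).
asr_questions = [
    "what is the transcription of the audio?",
    "please transcribe the speech into written text.",
]
ast_questions = [
    "translate the speech into German.",
    "what is the German translation of the audio?",
]

for path, questions in [("asr_questions.txt", asr_questions),
                        ("ast_questions.txt", ast_questions)]:
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(questions) + "\n")

# These files could then be paired with two datasets via:
#   ++model.data.train_ds.question_file=[asr_questions.txt,ast_questions.txt]
```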

#### Available Audio Encoders
Currently, all NeMo ASR models are supported. Other models may also work if they have an `encoder` attribute that returns a sequence of audio embeddings, a `preprocessor` that takes raw audio and returns a sequence of features for the encoder, and a `cfg` attribute that returns an `omegaconf.DictConfig` object with the model configuration. In addition to a local model, you can also set `pretrained_audio_model` to a model from NGC (e.g., `stt_en_fastconformer_transducer_large`) or Hugging Face (e.g., `nvidia/parakeet-rnnt-1.1b`), and the script will download the model and use it for training.
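
If you want to plug in a custom model, a quick duck-typing check along these lines can confirm it exposes the interface described above. This is a minimal sketch; the local `.nemo` path is just an example:

```python
import nemo.collections.asr as nemo_asr
from omegaconf import DictConfig

# Load a candidate audio encoder model from a local .nemo file.
model = nemo_asr.models.ASRModel.restore_from("audio-encoder.nemo")

# The SpeechLLM scripts expect these attributes on the audio encoder model.
assert hasattr(model, "encoder"), "missing `encoder` (should produce audio embeddings)"
assert hasattr(model, "preprocessor"), "missing `preprocessor` (raw audio -> features)"
assert isinstance(model.cfg, DictConfig), "`cfg` should be an omegaconf.DictConfig"
print("Model exposes the expected encoder/preprocessor/cfg interface.")
```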


### Inference

The config file for inference is `conf/modular_audio_gpt_config_eval.yaml`, where you mainly need to set the `model.data.test_ds` fields. An example of running inference is shown below:

```bash
ASR_MODEL=/path/to/asr-model.nemo  # required only if you froze the audio encoder during training
MEGATRON_CKPT=/path/to/megatron-llm.nemo
ALM_DIR=/path/to/nemo_experiments/job_name
ALM_YAML=$ALM_DIR/version_0/hparams.yaml
ALM_CKPT="$ALM_DIR/checkpoints/AudioGPT--validation_wer\=0.2-step\=100000-epoch\=0-last.ckpt"  # this checkpoint file only contains the trainable params

VAL_MANIFESTS="[/data/libri-test-other.json,/data/MCV_7.1_test.json,/data/wsj-test.json]"
VAL_NAMES="[ls-test-other,mcv7.1-test,wsj-test]"

NVTE_MASKED_SOFTMAX_FUSION=0 \
NVTE_FLASH_ATTN=0 \
NVTE_FUSED_ATTN=0 \
CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \
  model.restore_from_path=$MEGATRON_CKPT \
  model.pretrained_audio_model=$ASR_MODEL \
  model.peft.restore_from_path=$ALM_CKPT \
  model.peft.restore_from_hparams_path=$ALM_YAML \
  model.data.test_ds.manifest_filepath=$VAL_MANIFESTS \
  model.data.test_ds.names=$VAL_NAMES \
  model.data.test_ds.global_batch_size=8 \
  model.data.test_ds.micro_batch_size=8 \
  model.data.test_ds.tokens_to_generate=256 \
  ++inference.greedy=False \
  ++inference.top_k=50 \
  ++inference.top_p=0.95 \
  ++inference.temperature=0.4 \
  ++inference.repetition_penalty=1.2 \
  ++model.data.test_ds.output_dir=${ALM_DIR}
```


## Reference
[1] Chen, Z.\*, Huang, H.\*, Andrusenko, A., Hrinchuk, O., Puvvada, K.C., Li, J., Ghosh, S., Balam, J. and Ginsburg, B., 2023. SALM: Speech-augmented Language Model with In-context Learning for Speech Recognition and Translation. ICASSP'24. https://arxiv.org/abs/2310.09424