Add SpeechLM to main #8741

Merged: 121 commits merged into main on May 11, 2024

Commits (121)
d0fd89b
update package info
ericharper Jul 3, 2023
9f117ee
fix the mpt chatbot (#6957)
yidong72 Jul 3, 2023
61ac56c
Remove `compute_on_step` from metrics (#6979)
titu1994 Jul 5, 2023
b5cb783
Hybrid conformer export (#6983)
borisfom Jul 7, 2023
f08cb21
Cache handling without input tensors mutation (#6980)
borisfom Jul 7, 2023
cdf354c
fixes for spellmapper (#6994)
bene-ges Jul 9, 2023
66a85c9
Fixing an issue with confidence ensembles (#6987)
Kipok Jul 10, 2023
de52d95
[TTS] Append pretrained FastPitch & SpectrogamEnhancer pair to availa…
racoiaws Jul 11, 2023
7ff8b08
Add ASR with TTS Tutorial. Fix enhancer usage. (#6955)
artbataev Jul 12, 2023
1647612
install_bs (#7019)
karpnv Jul 13, 2023
fac7103
fix tab text gen (#7022)
yidong72 Jul 13, 2023
49b1aea
TE bug fix (#7027)
dimapihtar Jul 14, 2023
6db4097
Add support for Numba FP16 RNNT Loss (#6991) (#7038)
titu1994 Jul 17, 2023
9af84fb
Remove pyyaml (#7052)
titu1994 Jul 17, 2023
08b26b5
Fix typo and branch in tutorial (#7048)
artbataev Jul 18, 2023
8f3957f
Refined export_config (#7053)
borisfom Jul 18, 2023
c90625e
fix pos id - hf update (#7075)
ekmb Jul 19, 2023
41678cc
Fix documentation for Numba (#7065)
titu1994 Jul 19, 2023
192aa06
small Bugfix (#7079)
fayejf Jul 20, 2023
9ba5277
Fix caching bug in causal convolutions for cache-aware ASR models (#7…
VahidooX Jul 20, 2023
6e9df3d
Adding docs and models for multiple lookahead cache-aware ASR (#7067)
VahidooX Jul 21, 2023
5807931
fix syntax error introduced in PR-7079 (#7102)
bene-ges Jul 25, 2023
511bb75
fix links for TN (#7117)
ekmb Jul 27, 2023
d88c137
Add updated fc ctc and rnnt xxl models (#7128)
nithinraok Jul 29, 2023
8961cdf
update branch (#7135)
ericharper Jul 31, 2023
a5e02a8
Fixed main and merging this to r1.20 (#7127)
tango4j Jul 31, 2023
31c5b3d
fix default attention size (#7141)
nithinraok Aug 1, 2023
d5d600d
Update evaluator.py (#7151)
stevehuang52 Aug 2, 2023
2baef81
Eagerly accumulate embedding grads into fp32 buffer (#6958)
timmoon10 Aug 2, 2023
215cb9d
Modular SpeechLLM implementation for Sept. 2023 submission (SALM) (#7…
zhehuaichen Oct 9, 2023
eed65b7
Add few-shot in-context learning and MLP modality adapter (#7705)
stevehuang52 Oct 13, 2023
2ce0592
update for mlp modality adapter (#7715)
stevehuang52 Oct 13, 2023
8a48480
fix speechllm few-shot inference (#7732)
stevehuang52 Oct 16, 2023
e0b88f2
Add training support for multiple audios in a sample (#7796)
stevehuang52 Oct 24, 2023
209f752
Create README.md
stevehuang52 Jan 12, 2024
01dd0d6
Update README.md
stevehuang52 Jan 12, 2024
528d1bf
Update README.md
stevehuang52 Jan 12, 2024
d94f9dd
update
stevehuang52 Jan 12, 2024
dbad4ac
rename
stevehuang52 Jan 12, 2024
73736ad
update and refactor
stevehuang52 Jan 15, 2024
94bd346
Update SpeechLLM code (#8475)
stevehuang52 Feb 21, 2024
8afd277
Update README.md
stevehuang52 Feb 21, 2024
78c1e8e
update speechllm (#8486)
stevehuang52 Feb 22, 2024
2e74cd1
clean up
stevehuang52 Feb 22, 2024
5ff28a1
update doc and infer
stevehuang52 Feb 23, 2024
e1e825f
update doc
stevehuang52 Feb 23, 2024
99fb448
update doc
stevehuang52 Feb 23, 2024
446c6d9
update doc
stevehuang52 Feb 23, 2024
3d78dd7
update doc
stevehuang52 Feb 23, 2024
0916850
minor update
stevehuang52 Mar 18, 2024
db542b4
fix import
stevehuang52 Mar 18, 2024
fe7214b
clean up
stevehuang52 Mar 18, 2024
916324e
clean up
stevehuang52 Mar 20, 2024
98f86b5
fix pretrained info
stevehuang52 Mar 20, 2024
555a007
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 22, 2024
8f524e3
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 22, 2024
619d75d
update dockerfile
stevehuang52 Mar 22, 2024
c3ca938
update for merging main
stevehuang52 Mar 25, 2024
76db149
fix for merge main
stevehuang52 Mar 25, 2024
f7afea1
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 25, 2024
c99ad43
clean up docs
stevehuang52 Mar 25, 2024
7c9ded7
clean up
stevehuang52 Mar 25, 2024
4c4ac20
clean up
stevehuang52 Mar 25, 2024
afbc212
clean up
stevehuang52 Mar 25, 2024
6bce450
refactor
stevehuang52 Mar 25, 2024
b3f6156
clean up
stevehuang52 Mar 25, 2024
f63b8b8
update
stevehuang52 Mar 25, 2024
9dd72b6
clean up
stevehuang52 Mar 26, 2024
11facc7
fix speechlm test
stevehuang52 Mar 26, 2024
3da8282
update doc
stevehuang52 Mar 26, 2024
179fafd
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 Mar 26, 2024
14c1334
refactor
stevehuang52 Mar 26, 2024
98a0143
refactor
stevehuang52 Mar 27, 2024
7dbe84d
refactor
stevehuang52 Mar 27, 2024
3a039f5
fix multi-layer feat
stevehuang52 Mar 27, 2024
55c9e04
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 27, 2024
073212b
update for webdataset
stevehuang52 Mar 27, 2024
ba86fb9
refactor
stevehuang52 Apr 3, 2024
fdfe7b5
force str to avoid bugs with implicit conversion of str to bool type
stevehuang52 Apr 4, 2024
18b2921
Update examples/multimodal/speech_llm/README.md
stevehuang52 Apr 5, 2024
fef24dc
Update examples/multimodal/speech_llm/README.md
stevehuang52 Apr 5, 2024
c532150
refactor
stevehuang52 Apr 5, 2024
21d4261
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 Apr 5, 2024
c2f6b78
refactor
stevehuang52 Apr 5, 2024
647e184
update for saving nemo
stevehuang52 Apr 5, 2024
36df825
update eval and ngc ckpt
stevehuang52 Apr 5, 2024
f6a90d1
Update nemo/collections/multimodal/speech_llm/data/audio_text_qa_data…
stevehuang52 Apr 8, 2024
d73a684
Update nemo/collections/multimodal/speech_llm/modules/common/audio_te…
stevehuang52 Apr 8, 2024
3dea3ce
Update tests/collections/multimodal/test_speechllm_models.py
stevehuang52 Apr 8, 2024
aa4f85b
refactor and remove nlp adapter mixin assert
stevehuang52 Apr 8, 2024
9e10694
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 Apr 8, 2024
360acd4
remove random context augmentation
stevehuang52 Apr 8, 2024
6449924
fix docstring
stevehuang52 Apr 8, 2024
52617f9
add docstring
stevehuang52 Apr 8, 2024
7c78165
minor refactor
stevehuang52 Apr 11, 2024
ed29843
refactor
stevehuang52 Apr 11, 2024
5a4be92
refactor and fix missing import
stevehuang52 Apr 12, 2024
c991e5b
Merge branch 'main' into heh/modular_speechllm_pr
pablo-garay Apr 13, 2024
79156fc
major refactor on input format and minor update
stevehuang52 Apr 16, 2024
0268898
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 Apr 16, 2024
b6cac3d
fix codeQL
stevehuang52 Apr 17, 2024
8b19dc5
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 Apr 17, 2024
960f958
clean up
stevehuang52 Apr 17, 2024
2e18366
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Apr 24, 2024
8043262
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 6, 2024
55f8231
update for NGC ckpt and refactor
stevehuang52 May 6, 2024
d9e2788
clean up
stevehuang52 May 6, 2024
3cd12e9
Merge branch 'main' into heh/modular_speechllm_pr
ericharper May 7, 2024
30a583a
Merge branch 'main' into heh/modular_speechllm_pr
ericharper May 7, 2024
55b270b
Merge branch 'main' into heh/modular_speechllm_pr
nithinraok May 8, 2024
1c4cbd7
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 9, 2024
3e88457
skip speechlm test until data moved to CI machines
stevehuang52 May 9, 2024
17ab55b
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 10, 2024
6cae145
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 10, 2024
4cfaa30
refactor and update to avoid changing nlp_adapter_mixin
stevehuang52 May 10, 2024
27e33ee
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 May 10, 2024
67ecaa1
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 10, 2024
89926fa
Apply isort and black reformatting
stevehuang52 May 10, 2024
5c7f18e
minor fix
stevehuang52 May 11, 2024
cdcb258
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 May 11, 2024
1afc6e1
Apply isort and black reformatting
stevehuang52 May 11, 2024
163 changes: 163 additions & 0 deletions examples/multimodal/speech_llm/README.md
@@ -0,0 +1,163 @@
# Modular SpeechLLM

This directory contains example scripts to train and evaluate modular SpeechLLMs (e.g., SALM [1]).

## Requirements
You will need to install this specific branch of NeMo, or use the provided Dockerfile in the root directory of this repository to build a Docker image with all the necessary dependencies.

## Architecture

In general, a modular SpeechLLM has three main components (a minimal sketch of how they are composed is shown below):
- An audio encoder that processes the input audio and produces a sequence of audio embeddings.
- A modality adapter that processes the audio embeddings and produces a sequence of embeddings in the same latent space as the token embeddings of a pretrained large language model (LLM).
- A pretrained LLM that processes the embeddings from the modality adapter as well as the token embeddings of the input prompt, and produces the text output. The audio embeddings and text token embeddings are concatenated along the time dimension before going into the LLM.
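
The following is a minimal, self-contained sketch of how these three components fit together. The class and argument names are illustrative (not the actual NeMo classes); it only shows the shape of the data flow.

```python
import torch
import torch.nn as nn

class ModularSpeechLLMSketch(nn.Module):
    """Illustrative composition of audio encoder, modality adapter, and LLM."""

    def __init__(self, audio_encoder: nn.Module, modality_adapter: nn.Module, llm: nn.Module):
        super().__init__()
        self.audio_encoder = audio_encoder        # e.g. a (possibly frozen) Conformer encoder
        self.modality_adapter = modality_adapter  # maps encoder dim -> LLM hidden dim
        self.llm = llm                            # pretrained LLM that consumes input embeddings

    def forward(self, audio_features: torch.Tensor, prompt_embeddings: torch.Tensor) -> torch.Tensor:
        audio_emb = self.audio_encoder(audio_features)    # [B, T_audio, D_enc]
        audio_emb = self.modality_adapter(audio_emb)      # [B, T_audio, D_llm]
        # audio embeddings and text token embeddings are concatenated along the time dimension
        llm_inputs = torch.cat([audio_emb, prompt_embeddings], dim=1)  # [B, T_audio + T_text, D_llm]
        return self.llm(llm_inputs)                       # text logits
```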


## Usage

### Input Format

You'll need to prepare data in the NeMo manifest format, where each line is a JSON dictionary with fields such as:
```
{
"audio_filepath": "path/to/audio.wav",
"offset": 0.0, # offset of the audio in seconds, this is an optional field
"duration": 10.0 , # duration of the audio in seconds, can set to `None` to load the whole audio
"question": "what is the transcription of the audio?", # text prompt for the audio, see below for more details
"answer": "the transcription of the audio", # optional for inference, default to "na" in dataloader
stevehuang52 marked this conversation as resolved.
Show resolved Hide resolved
}
```
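
As a small illustration (file name and contents below are hypothetical), a manifest is just plain JSON lines and can be written like this:

```python
import json

samples = [
    {
        "audio_filepath": "path/to/audio1.wav",
        "duration": 10.0,
        "question": "what is the transcription of the audio?",
        "answer": "the transcription of the first audio",
    },
    {
        "audio_filepath": "path/to/audio2.wav",
        "duration": 7.5,
        "question": "what is the transcription of the audio?",
        "answer": "the transcription of the second audio",
    },
]

# one JSON dictionary per line, with no trailing commas or comments
with open("train_manifest.json", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")
```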

The `question` field in the manifest is optional. You can instead put a list of questions in a question file (one question per line) and set `++model.data.train_ds.question_file=<path to question file>` so that the dataloader randomly picks a question from the file for each audio sample. This is useful for training with multiple prompts for the same task. If neither the `question` field nor `question_file` is provided, the dataloader will use the default question `what does the audio mean?` for all audio samples. During inference, it is recommended to include the `question` field in the manifest.


### Training

There are several configs for training a SpeechLLM:
- `conf/modular_audio_gpt_config_peft.yaml`: a config for training a SpeechLLM with PEFT (e.g., LoRA), where you don't want to tune the whole LLM but still want to adapt the LLM to your needs.
> **Reviewer:** Refactor into a more hierarchical structure inside `conf/*/**`.
>
> **Reviewer:** ++

- `conf/modular_audio_gpt_config_sft.yaml`: a config for training a SpeechLLM without PEFT, where you may want to tune the whole LLM, or simply freeze it and use it as is.
- `conf/modular_audio_gpt_multi_enc_config_peft.yaml`: a config for training a SpeechLLM with multiple audio encoders and PEFT, where you can add speaker embeddings to the audio embeddings. Currently only TitaNet is supported as the speaker encoder.

With any config, you can set the following flags to control which components to train or freeze:
- `model.freeze_llm`: generally set to `True` unless you want to fine-tune the whole LLM.
- `model.freeze_audio_encoder`: generally set to `False` unless you want to freeze the audio encoder.
- `model.freeze_modality_adapter`: generally set to `False`, since the modality adapter usually needs to be trained.

In addition to the config file, you will also need to prepare the audio encoder and the LLM as `*.nemo` files.

To train a SpeechLLM that uses LoRA, you can run the following script:
```bash
MEGATRON_MODEL=/path/to/megatron-model.nemo
ASR_MODEL=/path/to/audio-model.nemo # only the encoder part will be loaded. e.g, stt_en_fastconformer_transducer_large.nemo

TRAIN_MANIFESTS="[/data/train_1.json,/data/train_2.json]"
VAL_MANIFESTS="[/data/dev_1.json,/data/dev_2.json]"
VAL_NAMES="[dev-1,dev-2]" # names to display when logging validation results for each dataset

CUDA_VISIBLE_DEVICES="0,1" python modular_audio_gpt_train.py --config-path="./conf" --config-name "modular_audio_gpt_config_peft" \
trainer.devices=-1 \
model.freeze_audio_encoder=True \
model.freeze_llm=True \
model.global_batch_size=4 \ # global_batch_size = micro_batch_size * num_gpus_per_node * num_nodes * accumulate_grad_batches
model.micro_batch_size=2 \ # micro_batch_size = batch_size_per_gpu
model.pretrained_audio_model=$ASR_MODEL \
model.restore_from_path=$MEGATRON_MODEL \
Comment on lines +89 to +90
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this naming scheme seems a bit off. Both are pretrained but one has pretrained name and other has restore_from.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes...restore_from_path is a megatron default name and we shouldn't change that...

model.data.train_ds.manifest_filepath=$TRAIN_MANIFESTS \
model.data.validation_ds.manifest_filepath=$VAL_MANIFESTS \
++model.data.validation_ds.names=$VAL_NAMES \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does names correspond to here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

names of the datasets, which will be used to display and log the results for each manifest. this is useful when the manifest name does not show the dataset's name. added clarification.

```

You can also use tarred datasets for faster training by converting normal NeMo datasets to tarred datasets using this [script](https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py) and follow the same dataset setting as shown in the script. Also, `accumulate_grad_batches` is automatically set by the model based on `global_batch_size` and `micro_batch_size`, so there's no need to manually calculate and set `trainer.accumulate_grad_batches`.
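
For example, using the values from the script above (and assuming a single node with 2 GPUs and no tensor/pipeline parallelism), the derived `accumulate_grad_batches` works out as follows; this is illustrative arithmetic only:

```python
global_batch_size = 4   # model.global_batch_size
micro_batch_size = 2    # model.micro_batch_size
num_gpus_per_node = 2   # e.g. CUDA_VISIBLE_DEVICES="0,1"
num_nodes = 1

data_parallel_size = num_gpus_per_node * num_nodes
accumulate_grad_batches = global_batch_size // (micro_batch_size * data_parallel_size)
print(accumulate_grad_batches)  # 1, i.e. no gradient accumulation needed in this setup
```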


#### **Multi-task Training**

To use a question file, set `++model.data.train_ds.question_file=<path to question file>` on the command line, or use multiple question files with `++model.data.train_ds.question_file=[<path to question file1>,<path to question file2>,...]`. If the number of question files equals the number of provided datasets, the dataloader will assign each question file to a dataset; otherwise, it will randomly pick a question file from all provided question files for each audio sample. Using multiple question files is useful for training on multiple tasks, where each task has its own set of prompts. Meanwhile, you can control the weights of different tasks/datasets by using concatenated tarred datasets, where you assign weights to datasets as follows:
```
++model.data.train_ds.is_tarred=True \
++model.data.train_ds.is_concat=True \
++model.data.train_ds.manifest_filepath=[/path/to/data1/tarred_audio_manifest.json,/path/to/data2/tarred_audio_manifest.json] \
++model.data.train_ds.tarred_audio_filepaths=[/path/to/data1/audio__OP_0..1023_CL_.tar,/path/to/data2/audio__OP_0..1023_CL_.tar] \
++model.data.train_ds.concat_sampling_technique='random' \
++model.data.train_ds.concat_sampling_probabilities=[0.4,0.6] \
```

> **Reviewer:** What is a question file?
>
> **Author:** it's defined in a previous paragraph: "The `question` field in the manifest is optional, and you can put a list of questions in a question file then set `++model.data.train_ds.question_file=<path to question file>` to ask the dataloader to randomly pick a question from the file for each audio sample."
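
For illustration, a question file is just a plain-text file with one prompt per line. The file names and prompts below are hypothetical:

```python
# write one hypothetical question file per task, one prompt per line
asr_questions = [
    "what is the transcription of the audio?",
    "transcribe the speech into written text.",
]
ast_questions = [
    "translate the speech into German.",
    "what is the German translation of the audio?",
]

for path, questions in [("asr_questions.txt", asr_questions), ("ast_questions.txt", ast_questions)]:
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(questions) + "\n")
```

These files could then be passed alongside the two datasets above, e.g. `++model.data.train_ds.question_file=[asr_questions.txt,ast_questions.txt]`.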

#### **Available Audio Encoders**

Currently all NeMo ASR models are supported; other models may also work if they have an `encoder` attribute that returns a sequence of audio embeddings, a `preprocessor` that takes raw audio and returns a sequence of features for the encoder, and a `cfg` attribute that returns an `omegaconf.DictConfig` of the model configuration. In addition to a local model, you can also set `pretrained_audio_model` to a model from NGC (e.g., `stt_en_fastconformer_transducer_large`) or Hugging Face (e.g., `nvidia/parakeet-rnnt-1.1b`), and the script will download the model and use it for training.
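
A quick sketch of that interface, assuming a standard NeMo ASR install; the attribute checks simply mirror the requirements listed above:

```python
import nemo.collections.asr as nemo_asr

# download a pretrained model from NGC/Hugging Face (or use restore_from() for a local .nemo file)
asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_transducer_large")

assert hasattr(asr_model, "preprocessor")  # raw audio -> features for the encoder
assert hasattr(asr_model, "encoder")       # features -> sequence of audio embeddings
assert hasattr(asr_model, "cfg")           # omegaconf.DictConfig with the model configuration
```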


### Inference

The script you need to perform inference is `modular_audio_gpt_eval.py`, and the corresponding config file is `conf/modular_audio_gpt_config_eval.yaml`, where you mainly need to set the `model.data.test_ds` fields as well as paths to the checkpoints.

#### **Inference with Intermediate Checkpoints**

If you want to perform inference with intermediate checkpoints, where there's no single NeMo checkpoint file that contains all the model parameters, you can use the following script to load each component from its own checkpoint file and perform inference:

```bash
MEGATRON_CKPT=/path/to/megatron-llm.nemo
ALM_DIR=/path/to/nemo_experiments/job_name
# below is the path to the config used during training
ALM_YAML=$ALM_DIR/version_0/hparams.yaml
# this checkpoint file only contains the trainable params; the backslash is used to avoid a hydra parsing error
ALM_CKPT="$ALM_DIR/checkpoints/AudioGPT--validation_wer\=0.2-step\=100000-epoch\=0-last.ckpt"

TEST_MANIFESTS="[/data/test_1.json,/data/test_2.json]"
TEST_NAMES="[test-1,test-2]"

CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \
    model.restore_from_path=$MEGATRON_CKPT \
    model.peft.restore_from_path=$ALM_CKPT \
    model.peft.restore_from_hparams_path=$ALM_YAML \
    model.data.test_ds.manifest_filepath=$TEST_MANIFESTS \
    model.data.test_ds.names=$TEST_NAMES \
    model.data.test_ds.metric.name="bleu" \
    model.data.test_ds.global_batch_size=8 \
    model.data.test_ds.micro_batch_size=8 \
    model.data.test_ds.tokens_to_generate=256 \
    ++inference.greedy=False \
    ++inference.top_k=50 \
    ++inference.top_p=0.95 \
    ++inference.temperature=0.4 \
    ++inference.repetition_penalty=1.2 \
    ++model.data.test_ds.output_dir=${ALM_DIR}
```

> **Reviewer:** While the script is fine, in a future PR let's look into minimizing it as much as possible, reducing the unnecessary parts, and providing a clean Pythonic API for inference.

If you froze the audio encoder during training, you will also need to add the following line to the above script:
```bash
++model.pretrained_audio_model=/path/to/audio/model.nemo
```

If you want to save the intermediate checkpoints to a single NeMo checkpoint file, you can add the following line to the above script:
```bash
++save_to_nemo=/path/to/save/model.nemo
```

#### **Inference with Complete SpeechLLM Checkpoints**

If you want to load a trained SpeechLLM from the cloud, you can use the following script:
```bash
TEST_MANIFESTS="[/data/test_1.json,/data/test_2.json]"
TEST_NAMES="[test-1,test-2]"

CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \
    model.from_pretrained="speechllm_fc_llama2_7b" \
    model.data.test_ds.manifest_filepath=$TEST_MANIFESTS \
    model.data.test_ds.names=$TEST_NAMES \
    model.data.test_ds.global_batch_size=8 \
    model.data.test_ds.micro_batch_size=8 \
    model.data.test_ds.tokens_to_generate=256 \
    ++inference.greedy=False \
    ++inference.top_k=50 \
    ++inference.top_p=0.95 \
    ++inference.temperature=0.4 \
    ++inference.repetition_penalty=1.2 \
    ++model.data.test_ds.output_dir="./test_outputs"
```

> **Reviewer (on `model.from_pretrained="speechllm_fc_llama2_7b"`):** is this published already?
>
> **Author:** not yet, but it will be soon, before the PR is merged.

If you have a local `.nemo` file, you can use `model.restore_from_path=/path/to/model.nemo` to replace the line `model.from_pretrained="speechllm_fc_llama2_7b"` in the above example.


## Reference
[1] Chen, Z.\*, Huang, H.\*, Andrusenko, A., Hrinchuk, O., Puvvada, K.C., Li, J., Ghosh, S., Balam, J. and Ginsburg, B., 2023. SALM: Speech-augmented Language Model with In-context Learning for Speech Recognition and Translation. ICASSP'24.
142 changes: 142 additions & 0 deletions examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml
@@ -0,0 +1,142 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
> **Reviewer (on file placement):** Do not keep this config at the `conf/*` level; move it to the `conf/*/**` level.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

> **Reviewer:** @titu1994 do we need to add a LICENSE header for configs and notebooks?
>
> **Reviewer:** Not for configs usually. Notebooks - hmm, not needed there either (at least we haven't done it before).

# this config is used to perform inference on SpeechLLM checkpoints
name: megatron_audio_gpt_eval
> **Reviewer:** add some info on what this config is for.
>
> **Author:** added.


trainer:
devices: 1
accelerator: gpu
num_nodes: 1
precision: bf16
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_epochs: 9999
> **Reviewer (on `max_epochs: 9999`):** minor: why not -1?
>
> **Author:** we still want to use epoch-based training and want wandb to log the epoch info, but Megatron does not work properly on LR scheduling with max_steps=-1, so we set a large max_steps to make sure the LR schedule is correct and also use max_epochs to log the epoch info.
>
> **Reviewer:** Are you saying that with epoch=-1 there is an issue while logging the epoch number to wandb?

max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10 # frequency with which training steps are logged
val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
gradient_clip_val: 1.0

exp_manager:
explicit_log_dir: null
exp_dir: null
name: ${name}
create_wandb_logger: False
wandb_logger_kwargs:
project: null
name: null
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: validation_${model.data.validation_ds.metric.name}
save_top_k: 1
mode: min
save_nemo_on_train_end: True
filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}'
model_parallel_size: ${model.tensor_model_parallel_size}
always_save_nemo: True
save_best_model: False

model:
from_pretrained: null # pretrained model name on NGC or HF
restore_from_path: null # Path to an existing .nemo model you wish to add new tasks to or run inference with
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
pretrained_audio_model: null # Path to a .nemo model for audio encoder

seed: 1234
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism

global_batch_size: 1
micro_batch_size: 1
sync_batch_comm: False
megatron_amp_O2: False

## Sequence Parallelism
# Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
sequence_parallel: False

## Activation Checkpoint
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null # not used with 'selective'
activations_checkpoint_layers_per_pipeline: null
answer_only_loss: False # not used right now
gradient_as_bucket_view: False

hidden_dropout: 0.0
attention_dropout: 0.0
ffn_dropout: 0.0

peft: # keep these basic params for reusing in both sft and peft SpeechLMs
restore_from_path: null
restore_from_hparams_path: null
restore_from_ckpt:
checkpoint_name: null
checkpoint_dir: null


data:
test_ds:
manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
names: null # Names of the corresponding datasets used to log metrics.
global_batch_size: 1
micro_batch_size: 1
shuffle: False
num_workers: 0
pin_memory: True
max_seq_length: 2048
min_seq_length: 1
drop_last: False
end_string: ${data.train_ds.end_string} # don't change, let hydra resolve from saved config
context_key: ${data.train_ds.context_key} # don't change, let hydra resolve from saved config
label_key: ${data.train_ds.label_key} # don't change, let hydra resolve from saved config
add_eos: ${data.train_ds.add_eos} # don't change, let hydra resolve from saved config
add_sep: ${data.train_ds.add_sep} # don't change, let hydra resolve from saved config
add_bos: ${data.train_ds.add_bos} # don't change, let hydra resolve from saved config
separate_prompt_and_response_with_newline: ${data.train_ds.separate_prompt_and_response_with_newline}
write_predictions_to_file: True
output_file_path_prefix: "preds" # Prefix of the file to write predictions to.
truncation_field: ${data.train_ds.truncation_field} # don't change, let hydra resolve from saved config
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${data.train_ds.prompt_template} # don't change, let hydra resolve from saved config
tokens_to_generate: 512
log_every_n_steps: 1
sample_rate: ${data.train_ds.sample_rate} # don't change, let hydra resolve from saved config
audio_locator: null # set it to allow multiple audios in a sample, e.g. '|audio|', and use it in the context field of manifest to specify the locations of audios (`audio_filepath` is a list of audios).

metric:
name: "bleu" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'wer', 'bleu', 'rouge']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
num_classes: null

save_as_nemo: null # optional string, set to save the whole model into a single nemo file
evaluate_metric: True # set 'true' to calculate metrics (must have labels in data), 'false' to use trainer.predict() to do inference only
inference:
greedy: True # Whether or not to use sampling ; use greedy decoding otherwise
top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: 1.0 # sampling temperature
all_probs: False # whether return the log prob for all the tokens in vocab
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False
outfile_path: output.txt
compute_attention_mask: True
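
For reference, here is a minimal sketch (not NeMo's actual implementation) of how the sampling options above (`greedy`, `temperature`, `top_k`, `top_p`) typically combine when picking the next token; `repetition_penalty` is omitted for brevity:

```python
import torch

def sample_next_token(logits, greedy=False, temperature=1.0, top_k=0, top_p=0.9):
    """logits: 1-D tensor of unnormalized scores over the vocabulary."""
    if greedy:
        return int(torch.argmax(logits))
    logits = logits / max(temperature, 1e-5)                 # <1 sharpens, >1 flattens the distribution
    if top_k > 0:                                            # keep only the k most likely tokens
        kth_best = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p < 1.0:                                          # nucleus (top-p) filtering
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        drop = (cum_probs - sorted_probs) > top_p            # drop tokens once top_p mass is already kept
        sorted_logits[drop] = float("-inf")
        logits = torch.full_like(logits, float("-inf")).scatter(0, sorted_idx, sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))
```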