diff --git a/examples/multimodal/speech_llm/README.md b/examples/multimodal/speech_llm/README.md new file mode 100644 index 000000000000..b6a9c7486331 --- /dev/null +++ b/examples/multimodal/speech_llm/README.md @@ -0,0 +1,189 @@ +# Modular SpeechLLM + +This directory contains example scripts to train and evaluate modular SpeechLLM (e.g, SALM[1], etc). + +## Requirements +You will need to install this specific branch of NeMo, or use the provided Dockerfile in the root directory of this repository to build a Docker image with all the necessary dependencies. + +## Architecture + +In general, there're three main components of a modular SpeechLLM: +- An audio encoder that processes the input audio and produces a sequence of audio embeddings. +- A modality adapter that processes the audio embeddings and produces a sequence of embeddings in the same latent space as the token embeddings of a pretrained large language model (LLM). +- A pretrained large language model (LLM) that processes embeddings from the modality adapter as well as token embeddings of input prompt, and produces the text output. The audio embeddings and text token embeddings are concatenated in time dimension before going into the LLM. +- The LLM produces text outputs based on the concatenated input audio and text embedding. + +## Usage + +### Input Format + +You'll need to prepare data in the NeMo manifest format, where each line is a python dictionary with some keys, for example: +``` +{ + "audio_filepath": "path/to/audio.wav", + "offset": 0.0, # offset of the audio in seconds, this is an optional field + "duration": 10.0 , # duration of the audio in seconds, can set to `None` to load the whole audio + "context": "what is the transcription of the audio?", # text prompt for the audio, see below for more details + "answer": "the transcription of the audio", # optional for inference, default to "na" in dataloader +} +``` + +The `context` field in the manifest is optional, and you can put a list of context in a context file (one context for each line) then set `++model.data.train_ds.context_file=` to ask the dataloader to randomly pick a context from the file for each audio sample. This is useful for training with multiple prompts for the same task. If neither `context` field nor `context_file` is provided, the dataloader will use a default context `what does the audio mean?` for all audios. During inference, it is recommended to have the `context` field in the manifest. + +#### **Customizing the fields to use** + +You can also use other fields in the manifest to replace the `context` and `answer`fields, but you'll also need to change the `prompt_template` to use the new field names. For example, if you desire to use the new fields `input_text` and `output_text`, you need to set: +```bash +++model.data.train_ds.context_key=input_text \ +++model.data.train_ds.answer_key=output_text \ +++model.data.train_ds.prompt_template="'Q: {input_text}\nA: {output_text}'" +``` +Note that there're single quotes around the prompt template (to avoid hydra errors), and the field names are wrapped in curly braces. + +#### **Customizing the input format** + +If you would like to use multiple audios, you can set the `audio_filepath` to be a list of audio file paths, and specify the location of each audio by using a special `audio_locator` string in the context. The choice of `audio_locator` should also be passed into the config. For example, if you have a manifest item like this: +``` +{ + "audio_filepath": ["path/to/audio1.wav", "path/to/audio2.wav"], + "context": "what is the transcription of the [audio] and [audio]?", # text prompt for the audio, see below for more details + "answer": "the transcription of the audio1 and audio2", # optional for inference, default to "na" in dataloader +} +``` +You can set the `audio_locator` to be `[audio]` in the config: +```bash +++model.data.train_ds.audio_locator='[audio]' +``` + +By using `audio_locator`, the dataloader will replace the `audio_locator` in the context with the corresponding audio features extracted for each audio. You need to make sure that the number of audio locators in the context matches the number of audio files in the `audio_filepath` field. + +### Training + +There are several configs for training a SpeechLLM: +- `conf/modular_audio_gpt_config_peft.yaml`: a config for training a SpeechLLM with PEFT (e.g., LoRA), where you don't want to tune the whole LLM but still want to adapt the LLM to your needs. +- `conf/modular_audio_gpt_config_sft.yaml`: a config for training a SpeechLLM without PEFT, where you might want to tune the whole LLM or simply freeze it and use as is. +- `conf/modular_audio_gpt_multi_enc_config_peft.yaml`: a config for training a SpeechLLM with multiple audio encoders and PEFT, where you can add speaker embeddings to the audio embeddings. Currently only TitaNet is supported as the speaker encoder. + +With any config, you can set the following flags to control which components to train or freeze: +- `model.freeze_llm` # Generally set to `True` unless you want to fine-tune the whole LLM. +- `model.freeze_audio_encoder` # Generally set to `False` unless you want to freeze the audio encoder. +- `model.freeze_modality_adapter` # Generally set to `False` since we want to train the modality adapter. + +In addition to the config file, you will also need to prepare the audio encoder and the LLM as `*.nemo` files. + +To train a SpeechLLM that uses LoRA, you can run the following script: +```bash +MEGATRON_MODEL=/path/to/megatron-model.nemo +ASR_MODEL=/path/to/audio-model.nemo # only the encoder part will be loaded. e.g, stt_en_fastconformer_transducer_large.nemo + +TRAIN_MANIFESTS="[/data/train_1.json,/data/train_2.json]" +VAL_MANIFESTS="[/data/dev_1.json,/data/dev_2.json]" +VAL_NAMES="[dev-1,dev-2]" # names to display when logging validation results for each dataset + +CUDA_VISIBLE_DEVICES="0,1" python modular_audio_gpt_train.py --config-path="./conf" --config-name "modular_audio_gpt_config_peft" \ + trainer.devices=-1 \ + model.freeze_audio_encoder=True \ + model.freeze_llm=True \ + model.global_batch_size=4 \ # global_batch_size = micro_batch_size * num_gpus_per_node * num_nodes * accumulate_grad_batches + model.micro_batch_size=2 \ # micro_batch_size = batch_size_per_gpu + model.pretrained_audio_model=$ASR_MODEL \ + model.restore_from_path=$MEGATRON_MODEL \ + model.data.train_ds.manifest_filepath=$TRAIN_MANIFESTS \ + model.data.validation_ds.manifest_filepath=$VAL_MANIFESTS \ + ++model.data.validation_ds.names=$VAL_NAMES \ +``` + +You can also use tarred datasets for faster training by converting normal NeMo datasets to tarred datasets using this [script](https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py) and follow the same dataset setting as shown in the script. Also, `accumulate_grad_batches` is automatically set by the model based on `global_batch_size` and `micro_batch_size`, so there's no need to manually calculate and set `trainer.accumulate_grad_batches`. + + +#### **Multi-task Training** + +In order to use a context file, you can set `++model.data.train_ds.context_file=` in the command line or use multiple context files with `++model.data.train_ds.context_file=[,,...]`. If the number of context files is equal to the number of provided datasets, the dataloader will assigne each context file to a dataset. Otherwise, the dataloader will randomly pick a context file from all provided context files for each audio sample. Using multiple context files is useful for training with multiple tasks, where each task has its own set of prompts. Meanwhile, you can control the weights for different tasks/datasets by using concatentated tarred datasets, where you can assign weights to datasets by: +``` +++model.data.train_ds.is_tarred=True \ +++model.data.train_ds.is_concat=True \ +++model.data.train_ds.manifest_filepath=[/path/to/data1/tarred_audio_manifest.json,/path/to/data2/tarred_audio_manifest.json] \ +++model.data.train_ds.tarred_audio_filepaths=[/path/to/data1/audio__OP_0..1023_CL_.tar,/path/to/data2/audio__OP_0..1023_CL_.tar] \ +++model.data.train_ds.concat_sampling_technique='random' \ +++model.data.train_ds.concat_sampling_probabilities=[0.4,0.6] \ +``` + +#### **Available Audio Encoders** + +Currently all NeMo ASR models are supported, others may also work if they have an `encoder` attribute that returns a sequence of audio embeddings, and a `preprocessor` that takes raw audios and returns a sequence of features for the encoder. The model should also have a `cfg` attribute that returns a `omegaconf.DictConfig` object of model configuration. In addition to a local model, you can also set `pretrained_audio_model` to a model from NGC (e.g., `stt_en_fastconformer_transducer_large`) or Huggingface (e.g., `nvidia/parakeet-rnnt-1.1b`), and the script will download the model and use it for training. + + +### Inference + +The script you need to perform inference is `modular_audio_gpt_eval.py`, and the corresponding config file is `conf/modular_audio_gpt_config_eval.yaml`, where you mainly need to set the `model.data.test_ds` fields as well as paths to the checkpoints. + +#### **Inference with Intermediate Checkpoints** + +If you want to perform inference with intermediate checkpoints, where there's no single NeMo checkpoint file that contains all the model parameters, you can use the following script to load each component from its own checkpoint file and perform inference: + +```bash +MEGATRON_CKPT=/path/to/megatron-llm.nemo +ALM_DIR=/path/to/nemo_experiments/job_name +# below is the path to the config used during training +ALM_YAML=$ALM_DIR/version_0/hparams.yaml +# this checkpoint file only contains the trainable params, the backslash is used to avoid hyrda parsing error +ALM_CKPT="$ALM_DIR/checkpoints/AudioGPT--validation_wer\=0.2-step\=100000-epoch\=0-last.ckpt" + +TEST_MANIFESTS="[/data/test_1.json,/data/test_2.json]" +TEST_NAMES="[test-1,test-2]" + +CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \ + model.restore_from_path=$MEGATRON_CKPT \ + model.peft.restore_from_path=$ALM_CKPT \ + model.peft.restore_from_hparams_path=$ALM_YAML \ + model.data.test_ds.manifest_filepath=$TEST_MANIFESTS \ + model.data.test_ds.names=$TEST_NAMES \ + model.data.test_ds.metric.name="bleu" \ + model.data.test_ds.global_batch_size=8 \ + model.data.test_ds.micro_batch_size=8 \ + model.data.test_ds.tokens_to_generate=256 \ + ++inference.greedy=False \ + ++inference.top_k=50 \ + ++inference.top_p=0.95 \ + ++inference.temperature=0.4 \ + ++inference.repetition_penalty=1.2 \ + ++model.data.test_ds.output_dir=${ALM_DIR} +``` + +If you froze the audio encoder during training, you will also need to add the following line to the above script: +```bash +++model.pretrained_audio_model=/path/to/audio/model.nemo +``` + +If you want to save the intermediate checkpoints to a single NeMo checkpoint file, you can add the following line to the above script: +```bash +++save_to_nemo=/path/to/save/model.nemo +``` + +#### **Inference with Complete SpeechLLM Checkpoints** + +If you want to load a trained SpeechLLM from cloud, you can use the following script: +```bash +TEST_MANIFESTS="[/data/test_1.json,/data/test_2.json]" +TEST_NAMES="[test-1,test-2]" + +CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \ + model.from_pretrained="speechllm_fc_llama2_7b" \ + model.data.test_ds.manifest_filepath=$TEST_MANIFESTS \ + model.data.test_ds.names=$TEST_NAMES \ + model.data.test_ds.global_batch_size=8 \ + model.data.test_ds.micro_batch_size=8 \ + model.data.test_ds.tokens_to_generate=256 \ + ++inference.greedy=False \ + ++inference.top_k=50 \ + ++inference.top_p=0.95 \ + ++inference.temperature=0.4 \ + ++inference.repetition_penalty=1.2 \ + ++model.data.test_ds.output_dir="./test_outputs" +``` + +If you have a local `.nemo` file, you can use `model.restore_from_path=/path/to/model.nemo` to replace the line `model.from_pretrained="speechllm_fc_llama2_7b"` in the above example. + + +## Reference +[1] Chen, Z.\*, Huang, H.\*, Andrusenko, A., Hrinchuk, O., Puvvada, K.C., Li, J., Ghosh, S., Balam, J. and Ginsburg, B., 2023. SALM: Speech-augmented Language Model with In-context Learning for Speech Recognition and Translation. ICASSP'24. \ No newline at end of file diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml new file mode 100644 index 000000000000..e2ef61a8046d --- /dev/null +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml @@ -0,0 +1,128 @@ +# this config is used to perform inference on SpeechLLM checkpoints +name: megatron_audio_gpt_eval + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 1 + max_steps: 1000000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1.0 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: True + save_best_model: False + +model: + from_pretrained: null # pretrained model name on NGC or HF + restore_from_path: null # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + pretrained_audio_model: null # Path to a .nemo model for audio encoder + + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + global_batch_size: 1 + micro_batch_size: 1 + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: False # not used right now + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + peft: # keep these basic params for reusing in both sft and peft SpeechLMs + restore_from_path: null + restore_from_hparams_path: null + restore_from_ckpt: + checkpoint_name: null + checkpoint_dir: null + + + data: + test_ds: + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: 1 + micro_batch_size: 1 + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + end_string: ${data.train_ds.end_string} # don't change, let hydra resolve from saved config + context_key: ${data.train_ds.context_key} # don't change, let hydra resolve from saved config + answer_key: ${data.train_ds.answer_key} # don't change, let hydra resolve from saved config + add_eos: ${data.train_ds.add_eos} # don't change, let hydra resolve from saved config + add_sep: ${data.train_ds.add_sep} # don't change, let hydra resolve from saved config + add_bos: ${data.train_ds.add_bos} # don't change, let hydra resolve from saved config + separate_prompt_and_response_with_newline: ${data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: True + output_file_path_prefix: "preds" # Prefix of the file to write predictions to. + truncation_field: ${data.train_ds.truncation_field} # don't change, let hydra resolve from saved config + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${data.train_ds.prompt_template} # don't change, let hydra resolve from saved config + tokens_to_generate: 512 + log_every_n_steps: 1 + sample_rate: ${data.train_ds.sample_rate} # don't change, let hydra resolve from saved config + audio_locator: null # set it to allow multiple audios in a sample, e.g. '|audio|', and use it in the context field of manifest to specify the locations of audios (`audio_filepath` is a list of audios). + + metric: + name: "bleu" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'wer', 'bleu', 'rouge'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + +save_as_nemo: null # optional string, set to save the whole model into a single nemo file + +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + outfile_path: output.txt + compute_attention_mask: True diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_peft.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_peft.yaml new file mode 100644 index 000000000000..172a8f37cf1c --- /dev/null +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_peft.yaml @@ -0,0 +1,327 @@ +name: megatron_audio_gpt_peft + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 1000 # used to keep epoch logging correctly, but training will stop based on max_steps + max_steps: 1000000 # 1M steps + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 3000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: ??? + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + peft: + peft_scheme: "lora" # can be either lora, adapter, ia3 or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv','attention_dense','mlp_fc1','mlp_fc2'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: ${model.peft.lora_tuning.adapter_dim} + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + + perception: + use_multi_layer_feat: false # whether to extract multi-layer features, only supports conformer encoder + multi_layer_feat: + layer_idx_list: [0,16] # layer indices to extract features from + aggregator: + mode: "cat" # ways to combine features from different layers, choices=['cat','sum','mean', 'max', 'min'], default to concat ('cat') + pooling: "avg" # ways to pool features if they have different temporal lengths and align_mode=min, choices=['mean', 'max', 'min'] + align_mode: "min" # if features have different temporal lengths, set `min` to pool to the shortest length or `max` to repeat to the longest. + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained audio encoder: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + end_string: "[EOG]" + train_ds: + # Example of how to specify paths to multiple datasets + # manifest_filepath: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'context': 'transcribe this audio', 'answer': 'I have a dream...'} + # the 'answer' field can also be 'text', and a default 'context' field is added if missing in manigests, so as to work with ASR manifests + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + sample_alpha: null + audio_locator: null + + validation_ds: + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + audio_locator: ${model.data.train_ds.audio_locator} + + log_every_n_steps: 10 + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'wer', 'bleu', 'rouge'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + # test_ds: + # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + # names: null # Names of the corresponding datasets used to log metrics. + # global_batch_size: ${model.global_batch_size} + # micro_batch_size: ${model.micro_batch_size} + # shuffle: False + # num_workers: 4 + # pin_memory: True + # max_seq_length: 2048 + # min_seq_length: 1 + # drop_last: False + # context_key: 'context' + # answer_key: 'answer' + # add_eos: ${model.data.train_ds.add_eos} + # end_string: ${model.data.end_string} + # add_sep: ${model.data.train_ds.add_sep} + # add_bos: ${model.data.train_ds.add_bos} + # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + # write_predictions_to_file: False + # output_file_path_prefix: null # Prefix of the file to write predictions to. + # truncation_field: "context" # Options: ['context', 'answer'] + # index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: ${model.data.train_ds.prompt_template} + # # ASR configs + # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + # metric: + # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + # average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + # num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.001 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 5000 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_sft.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_sft.yaml new file mode 100644 index 000000000000..7f8512fbb19e --- /dev/null +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_sft.yaml @@ -0,0 +1,299 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: megatron_audio_gpt_sft + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 1000 # used to keep epoch logging correctly, but training will stop based on max_steps + max_steps: 1000000 # 1M steps + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 3000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: ??? + freeze_llm: True + freeze_audio_encoder: True + freeze_modality_adapter: False + + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + perception: + use_multi_layer_feat: false + multi_layer_feat: + layer_idx_list: [0,16] + aggregator: + mode: "cat" # ways to combine features from different layers, choices=['cat','sum','mean', 'max', 'min'], default to concat ('cat') + pooling: "avg" # ways to pool features if they have different temporal lengths and align_mode=min, choices=['mean', 'max', 'min'] + align_mode: "min" # if features have different temporal lengths, set `min` to pool to the shortest length or `max` to repeat to the longest. + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained audio encoder: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + end_string: "[EOG]" + train_ds: + # Example of how to specify paths to multiple datasets + # manifest_filepath: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'context': 'transcribe this audio', 'answer': 'I have a dream...'} + # the 'answer' field can also be 'text', and a default 'context' field is added if missing in manigests, so as to work with ASR manifests + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + sample_alpha: null + audio_locator: null + + validation_ds: + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + audio_locator: ${model.data.train_ds.audio_locator} + + log_every_n_steps: 10 + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'wer', 'bleu', 'rouge'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + # test_ds: + # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + # names: null # Names of the corresponding datasets used to log metrics. + # global_batch_size: ${model.global_batch_size} + # micro_batch_size: ${model.micro_batch_size} + # shuffle: False + # num_workers: 4 + # pin_memory: True + # max_seq_length: 2048 + # min_seq_length: 1 + # drop_last: False + # context_key: 'context' + # answer_key: 'answer' + # add_eos: ${model.data.train_ds.add_eos} + # end_string: ${model.data.end_string} + # add_sep: ${model.data.train_ds.add_sep} + # add_bos: ${model.data.train_ds.add_bos} + # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + # write_predictions_to_file: False + # output_file_path_prefix: null # Prefix of the file to write predictions to. + # truncation_field: "context" # Options: ['context', 'answer'] + # index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: ${model.data.train_ds.prompt_template} + # # ASR configs + # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + # metric: + # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + # average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + # num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.001 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 5000 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_multi_enc_config_peft.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_multi_enc_config_peft.yaml new file mode 100644 index 000000000000..656e7df287f1 --- /dev/null +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_multi_enc_config_peft.yaml @@ -0,0 +1,307 @@ +name: megatron_audio_gpt_multi_enc_peft_tuning + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 10000 # used to keep epoch logging correctly, but training will stop based on max_steps + max_steps: 1000000 # 1M steps + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 3000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + freeze_llm: True + freeze_audio_encoder: True + freeze_modality_adapter: False + + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + peft: + peft_scheme: "lora" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv','attention_dense','mlp_fc1','mlp_fc2'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: ${model.peft.lora_tuning.adapter_dim} + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + + perception: + modality_adapter: + _target_: nemo.collections.multimodal.speech_llm.modules.PoolingMLPConnectors + hidden_dim: 512 + pooling: 'cat' + pooling_factor: 2 + num_layers: 4 + input_dim: -1 + output_dim: -1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoders: + # use `target` instead of `_target_` to avoid auto initialization by hydra, need to do manual instantiation + asr_model: + target: nemo.collections.asr.models.EncDecRNNTBPEModel + model_dim_key: d_model + freeze: True + pretrained_model: stt_en_fastconformer_transducer_large + ssl_model: + target: nemo.collections.asr.models.SpeechEncDecSelfSupervisedModel + model_dim_key: d_model + freeze: True + pretrained_model: ssl_en_conformer_large + use_multi_layer_feat: True + multi_layer_feat: + layer_idx_list: [0,16] + aggregator: + mode: "cat" + pooling: "avg" + rounding: "floor" + + speaker_model: + segment_length_in_secs: 0.4 + freeze: True + pretrained_model: titanet_large + + ref_model: asr_model + aggregator: + mode: "cat" + pooling: "mean" + rounding: "floor" + + # the following are read from the pretrained audio encoder: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + end_string: "[EOG]" + train_ds: + # Example of how to specify paths to multiple datasets + # manifest_filepath: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'context': 'transcribe this audio', 'answer': 'I have a dream...'} + # the 'answer' field can also be 'text', and a default 'context' field is added if missing in manigests, so as to work with ASR manifests + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + # add_eos: True + add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + sample_alpha: null + audio_locator: null + + validation_ds: + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + audio_locator: ${model.data.train_ds.audio_locator} + + log_every_n_steps: 20 + metric: + name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + # test_ds: + # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + # names: null # Names of the corresponding datasets used to log metrics. + # global_batch_size: ${model.global_batch_size} + # micro_batch_size: ${model.micro_batch_size} + # shuffle: False + # num_workers: 4 + # pin_memory: True + # max_seq_length: 2048 + # min_seq_length: 1 + # drop_last: False + # context_key: 'context' + # answer_key: 'answer' + # add_eos: ${model.data.train_ds.add_eos} + # end_string: ${model.data.end_string} + # add_sep: ${model.data.train_ds.add_sep} + # add_bos: ${model.data.train_ds.add_bos} + # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + # write_predictions_to_file: False + # output_file_path_prefix: null # Prefix of the file to write predictions to. + # truncation_field: "context" # Options: ['context', 'answer'] + # index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: ${model.data.train_ds.prompt_template} + # # ASR configs + # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + # metric: + # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + # average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + # num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.001 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 5000 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/conf/salm/salm_config.yaml b/examples/multimodal/speech_llm/conf/salm/salm_config.yaml new file mode 100644 index 000000000000..c49e335c8d66 --- /dev/null +++ b/examples/multimodal/speech_llm/conf/salm/salm_config.yaml @@ -0,0 +1,339 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: salm_fastconformer_gpt_lora_tuning + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 100 + max_steps: 1000000 # 1M steps + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 3000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + peft: + peft_scheme: "lora" # can be either lora, adapter, ia3 or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv','attention_dense','mlp_fc1','mlp_fc2'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: ${model.peft.lora_tuning.adapter_dim} + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + + perception: + use_multi_layer_feat: false # whether to extract multi-layer features, only supports conformer encoder + multi_layer_feat: + layer_idx_list: [0,16] # layer indices to extract features from + aggregator: + mode: "cat" # ways to combine features from different layers, choices=['cat','sum','mean', 'max', 'min'], default to concat ('cat') + pooling: "avg" # ways to pool features if they have different temporal lengths and align_mode=min, choices=['mean', 'max', 'min'] + align_mode: "min" # if features have different temporal lengths, set `min` to pool to the shortest length or `max` to repeat to the longest. + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained audio encoder: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + end_string: "[EOG]" + train_ds: + # Example of how to specify paths to multiple datasets + # manifest_filepath: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'question': 'transcribe this audio', 'answer': 'I have a dream...'} + # the 'answer' field can also be 'text', and a default 'question' field is added if missing in manigests, so as to work with ASR manifests + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + # sample_alpha: 0.1 + + validation_ds: + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + log_every_n_steps: 10 + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'wer', 'bleu', 'rouge'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + # test_ds: + # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + # names: null # Names of the corresponding datasets used to log metrics. + # global_batch_size: ${model.global_batch_size} + # micro_batch_size: ${model.micro_batch_size} + # shuffle: False + # num_workers: 4 + # pin_memory: True + # max_seq_length: 2048 + # min_seq_length: 1 + # drop_last: False + # context_key: 'input' + # answer_key: 'output' + # add_eos: ${model.data.train_ds.add_eos} + # end_string: ${model.data.end_string} + # add_sep: ${model.data.train_ds.add_sep} + # add_bos: ${model.data.train_ds.add_bos} + # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + # write_predictions_to_file: False + # output_file_path_prefix: null # Prefix of the file to write predictions to. + # truncation_field: "context" # Options: ['context', 'answer'] + # index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: ${model.data.train_ds.prompt_template} + # # ASR configs + # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + # metric: + # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + # average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + # num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.001 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 2000 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/modular_audio_gpt_eval.py b/examples/multimodal/speech_llm/modular_audio_gpt_eval.py new file mode 100644 index 000000000000..d76e479829fa --- /dev/null +++ b/examples/multimodal/speech_llm/modular_audio_gpt_eval.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils import logging + +mp.set_start_method("spawn", force=True) + +""" +This is the script to run inference with a ModularAudioGPTModel. + +If you want to evaluate an ModularAudioGPTModel: + +MEGATRON_CKPT=/path/to/megatron-llm.nemo +ALM_DIR=/path/to/nemo_experiments/job_name +ALM_YAML=$ALM_DIR/version_0/hparams.yaml +ALM_CKPT="$ALM_DIR/checkpoints/AudioGPT--validation_wer\=0.5-step\=103-epoch\=0-last.ckpt" + +VAL_MANIFESTS="[/data/libri-test-other.json,/data/MCV_7.1_test.json,/data/wsj-test.json]" +VAL_NAMES="[ls-test-other,mcv7.1-test,wsj-test]" + +HYDRA_FULL_ERROR=1 \ +CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \ + model.restore_from_path=$MEGATRON_CKPT \ + model.peft.restore_from_path=$ALM_CKPT \ + model.peft.restore_from_hparams_path=$ALM_YAML \ + model.data.test_ds.manifest_filepath=$VAL_MANIFESTS \ + model.data.test_ds.names=$VAL_NAMES \ + model.data.test_ds.global_batch_size=8 \ + model.data.test_ds.micro_batch_size=8 \ + model.data.test_ds.tokens_to_generate=256 \ + ++inference.greedy=False \ + ++inference.top_k=50 \ + ++inference.top_p=0.95 \ + ++inference.temperature=0.4 \ + ++inference.repetition_penalty=1.2 \ + ++model.data.test_ds.output_dir=${ALM_DIR} +""" + + +@hydra_runner(config_path="conf", config_name="modular_audio_gpt_config_eval") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + logging.info("**************************************************\n\n") + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + + if cfg.model.from_pretrained: + # Load model from NGC or HuggingFace + logging.info(f"Loading model from cloud: {cfg.model.from_pretrained}") + model_cfg = ModularAudioGPTModel.from_pretrained( + cfg.model.from_pretrained, trainer=trainer, return_config=True + ) + model_cfg = ModularAudioGPTModel.merge_inference_cfg(cfg, trainer, model_cfg) + model_file = ModularAudioGPTModel.from_pretrained( + cfg.model.from_pretrained, trainer=trainer, return_model_file=True + ) + model = ModularAudioGPTModel.restore_from( + restore_path=model_file, + trainer=trainer, + override_config_path=model_cfg, + strict=False, + map_location="cpu", + ) + if "peft" in model_cfg and model_cfg.peft.get("peft_scheme", None): + # need this due to the way that MegatronGPTSFTModel doesn't load adapters in model initialization + model.load_adapters(model_file, map_location="cpu") + else: + # Load model from a local file + model_cfg = ModularAudioGPTModel.merge_inference_cfg(cfg, trainer) + model = ModularAudioGPTModel.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + override_config_path=model_cfg, + strict=False, + map_location="cpu", + ) + model = ModularAudioGPTModel.load_adapters_for_inference(cfg, model_cfg, model) + model = ModularAudioGPTModel.load_audio_encoder_for_inference(cfg, model_cfg, model) + + model.freeze() + if cfg.get("save_as_nemo", None): + model.setup("predict") # need to call setup() to load adapters and prepare for saving + model.save_to(cfg.save_as_nemo) + logging.info(f"Model saved to {Path(cfg.save_as_nemo).absolute()}, exiting...") + exit(0) + + if not cfg.model.get('use_flash_attention', False): + cfg.inference.compute_attention_mask = True + config = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(config) + + # run inference + trainer.test(model) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/speech_llm/modular_audio_gpt_train.py b/examples/multimodal/speech_llm/modular_audio_gpt_train.py new file mode 100644 index 000000000000..04bff37e7a3f --- /dev/null +++ b/examples/multimodal/speech_llm/modular_audio_gpt_train.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +mp.set_start_method("spawn", force=True) + +""" +MEGATRON_CKPT=/path/to/megatron-llm.nemo +ASR_MODEL=/path/to/asr-model.nemo + +TRAIN_MANIFESTS="[/data/train_1.json,/data/train_2.json]" +VAL_MANIFESTS="[/data/dev_1.json,/data/dev_2.json]" +VAL_NAMES="[dev-1,dev-2]" + +CUDA_VISIBLE_DEVICES="0,1" python modular_audio_gpt_train.py --config-path="./conf" --config-name "modular_audio_gpt_config_peft" \ + trainer.devices=-1 \ + model.freeze_audio_encoder=True \ + model.freeze_llm=True \ + model.global_batch_size=4 \ + model.micro_batch_size=2 \ + model.pretrained_audio_model=$ASR_MODEL \ + model.restore_from_path=$MEGATRON_MODEL \ + model.data.train_ds.manifest_filepath=$TRAIN_MANIFESTS \ + model.data.validation_ds.manifest_filepath=$VAL_MANIFESTS \ + ++model.data.validation_ds.names=$VAL_NAMES \ +""" + + +@hydra_runner(config_path="conf", config_name="modular_audio_gpt_config_peft") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + precision = cfg.trainer.precision + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + cfg.trainer.precision = precision + + exp_manager(trainer, cfg.exp_manager) + # update resume from checkpoint found by exp_manager + logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}') + + model = ModularAudioGPTModel.restore_from_pretrained_models(cfg, trainer=trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index b9642b3ea5dc..d0e014e42a37 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -16,7 +16,7 @@ import random from collections import OrderedDict from dataclasses import dataclass -from typing import List, Optional, Set +from typing import List, Optional, Set, Tuple import torch import torch.distributed @@ -356,7 +356,9 @@ def __init__( if reduction and reduction_factor > 1: assert reduction_position >= -1 and reduction_position < n_layers self.reduction_subsampling = SubsamplingReductionModule( - reduction=reduction, d_model=d_model, reduction_factor=reduction_factor, + reduction=reduction, + d_model=d_model, + reduction_factor=reduction_factor, ) self.reduction_position = reduction_position else: @@ -804,15 +806,15 @@ def setup_streaming_params( max_context: int = 10000, ): """ - This function sets the needed values and parameters to perform streaming. The configuration would be stored in self.streaming_cfg. - The streaming configuration is needed to simulate streaming inference. - - Args: - chunk_size (int): overrides the chunk size - shift_size (int): overrides the shift size for chunks - left_chunks (int): overrides the number of left chunks visible to each chunk - max_context (int): the value used for the cache size of last_channel layers if left context is set to infinity (-1) - Defaults to -1 (means feat_out is d_model) + This function sets the needed values and parameters to perform streaming. The configuration would be stored in self.streaming_cfg. + The streaming configuration is needed to simulate streaming inference. + + Args: + chunk_size (int): overrides the chunk size + shift_size (int): overrides the shift size for chunks + left_chunks (int): overrides the number of left chunks visible to each chunk + max_context (int): the value used for the cache size of last_channel layers if left context is set to infinity (-1) + Defaults to -1 (means feat_out is d_model) """ streaming_cfg = CacheAwareStreamingConfig() @@ -903,12 +905,19 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None create_tensor = torch.zeros last_time_cache_size = self.conv_context_size[0] cache_last_channel = create_tensor( - (len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,), + ( + len(self.layers), + batch_size, + self.streaming_cfg.last_channel_cache_size, + self.d_model, + ), device=device, dtype=dtype, ) cache_last_time = create_tensor( - (len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype, + (len(self.layers), batch_size, self.d_model, last_time_cache_size), + device=device, + dtype=dtype, ) if max_dim > 0: cache_last_channel_len = torch.randint( @@ -934,7 +943,6 @@ def change_attention_model( update_config: bool = True, device: torch.device = None, ): - """ Update the self_attention_model which changes the positional encoding and attention layers. @@ -1053,7 +1061,7 @@ def change_attention_model( def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): """ - Update the conv_chunking_factor (int) + Update the conv_chunking_factor (int) Default is 1 (auto) Set it to -1 (disabled) or to a specific value (power of 2) if you OOM in the conv subsampling layers @@ -1098,7 +1106,9 @@ def _update_adapter_cfg_input_dim(self, cfg: DictConfig): cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) return cfg - def get_accepted_adapter_types(self,) -> Set[type]: + def get_accepted_adapter_types( + self, + ) -> Set[type]: types = super().get_accepted_adapter_types() if len(types) == 0: @@ -1113,6 +1123,85 @@ def get_accepted_adapter_types(self,) -> Set[type]: return types +class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable, AccessMixin): + """ + A wrapper module that extracts features from multiple layers of a ConformerEncoder, + by reusing existing mechanisim for interctc loss. + To use it, set `layer_idx_list` to specify the indices of layers to extract from. + Also, you can specify an `aggretator` module to aggregate the features from different layers, default not aggregating. + """ + + def __init__( + self, + encoder: ConformerEncoder, + layer_idx_list: List[int], + aggregator: NeuralModule = None, + detach: bool = False, + convert_to_cpu: bool = False, + ): + super().__init__() + self.encoder = encoder + self.layer_idx_list = [int(l) for l in layer_idx_list] + for x in self.layer_idx_list: + if x < 0 or x >= len(encoder.layers): + raise ValueError(f"layer index {x} out of range [0, {len(encoder.layers)})") + self.enc_access_cfg = { + "interctc": { + "capture_layers": self.layer_idx_list, + }, + "detach": detach, + "convert_to_cpu": convert_to_cpu, + } + self.aggregator = aggregator + + def forward( + self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + ) -> Tuple[torch.Tensor, torch.Tensor]: + old_access_flag = self.is_access_enabled(guid=getattr(self, "model_guid", None)) + self.update_access_cfg(self.enc_access_cfg, guid=getattr(self, "model_guid", None)) + self.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) + + _ = self.encoder( + audio_signal=audio_signal, + length=length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + ### chunk of code adapted from ConformerEncoder.forward_internal() + total_registry = {} + for module_registry in self.get_module_registry(self.encoder).values(): + for key in module_registry: + if key.startswith("interctc/") and key in total_registry: + raise RuntimeError(f"layer {key} has been logged multiple times!") + total_registry.update(module_registry) + + encoded_list = [] + encoded_len_list = [] + for layer_idx in self.layer_idx_list: + try: + layer_outputs = total_registry[f"interctc/layer_output_{layer_idx}"] + layer_lengths = total_registry[f"interctc/layer_length_{layer_idx}"] + except KeyError: + raise RuntimeError( + f"Intermediate layer {layer_idx} was not captured! Check the layer index and the number of ConformerEncoder layers." + ) + if len(layer_outputs) > 1 or len(layer_lengths) > 1: + raise RuntimeError("Make sure encoder.forward is called exactly one time") + encoded_list.append(layer_outputs[0]) # [B, D, T] + encoded_len_list.append(layer_lengths[0]) # [B] + + self.encoder.reset_registry() + self.set_access_enabled(access_enabled=old_access_flag, guid=getattr(self, "model_guid", None)) + ### end of adapted chunk + + if self.aggregator is not None: + return self.aggregator(encoded_list, encoded_len_list) # Tensor[B,D*L,T], Tensor[B] + else: + return encoded_list, encoded_len_list # List[Tensor[B,D,T]], List[Tensor[B]] + + """ Register any additional information """ diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 5a71679607be..c252d498dc08 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -67,18 +67,18 @@ class TranscribeConfig: _internal: Optional[InternalTranscribeConfig] = None -def move_to_device(batch, device): +def move_to_device(batch, device, non_blocking=False): """ Recursively move all tensors in `batch` to `device`. """ if isinstance(batch, torch.Tensor): - return batch.to(device) + return batch.to(device, non_blocking=non_blocking) elif isinstance(batch, (list, tuple)): - return [move_to_device(x, device) for x in batch] + return [move_to_device(x, device, non_blocking) for x in batch] elif isinstance(batch, dict): - return {k: move_to_device(v, device) for k, v in batch.items()} + return {k: move_to_device(v, device, non_blocking) for k, v in batch.items()} else: - raise TypeError(f"Unsupported type: {type(batch)}") + return batch # do nothing if not supported type def get_value_from_transcription_config(trcfg, key, default): diff --git a/nemo/collections/common/data/dataset.py b/nemo/collections/common/data/dataset.py index c2c29b54f7f6..71220dd9d5f2 100644 --- a/nemo/collections/common/data/dataset.py +++ b/nemo/collections/common/data/dataset.py @@ -26,12 +26,12 @@ class ConcatDataset(IterableDataset): """ - A dataset that accepts as argument multiple datasets and then samples from them based on the specified + A dataset that accepts as argument multiple datasets and then samples from them based on the specified sampling technique. Args: datasets (list): A list of datasets to sample from. - shuffle (bool): Whether to shuffle individual datasets. Only works with non-iterable datasets. + shuffle (bool): Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True. sampling_technique (str): Sampling technique to choose which dataset to draw a sample from. Defaults to 'temperature'. Currently supports 'temperature', 'random' and 'round-robin'. @@ -73,7 +73,9 @@ def __init__( self.sampling_kwargs['seed'] = seed elif sampling_technique == 'random': self.index_generator = ConcatDataset.random_generator - self.sampling_kwargs['p'] = sampling_probabilities + self.sampling_kwargs['p'] = ( + sampling_probabilities if sampling_probabilities else [1 / len(datasets)] * len(datasets) + ) self.sampling_kwargs['seed'] = seed elif sampling_technique == 'round-robin': self.index_generator = ConcatDataset.round_robin_generator @@ -200,7 +202,7 @@ def random_generator(datasets, **kwargs): class ConcatMapDataset(Dataset): """ - A dataset that accepts as argument multiple datasets and then samples from them based on the specified + A dataset that accepts as argument multiple datasets and then samples from them based on the specified sampling technique. Args: @@ -300,7 +302,7 @@ class CodeSwitchedDataset(IterableDataset): Args: datasets (list): A list of datasets lang_probs (list): A list of probabilities (which must sum to 1) corresponding to the sampling probability for each dataset - shuffle (bool): Whether to shuffle individual datasets. Only works with non-iterable datasets. + shuffle (bool): Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True. min_duration (int): the minimum duration (secs) of each synthetic code-switched sample. Will draw randomly until this is hit. Defaults to 4 @@ -535,7 +537,7 @@ def build_single_CS_sample(self): wav = np.trim_zeros(wav) # normalise to provided DB level - wav_norm = wav * (10.0 ** (self.db_norm / 20.0) / np.maximum(0.01, (wav ** 2).mean(axis=0) ** 0.5)) + wav_norm = wav * (10.0 ** (self.db_norm / 20.0) / np.maximum(0.01, (wav**2).mean(axis=0) ** 0.5)) # this part appends the normed waveform to the existing waveform, and inserts pause_join amount of silence # if necessary, otherwise just a straight append diff --git a/nemo/collections/common/metrics/__init__.py b/nemo/collections/common/metrics/__init__.py index 322e62214ead..9e21d93816a9 100644 --- a/nemo/collections/common/metrics/__init__.py +++ b/nemo/collections/common/metrics/__init__.py @@ -14,5 +14,9 @@ from nemo.collections.common.metrics.classification_accuracy import TopKClassificationAccuracy from nemo.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric -from nemo.collections.common.metrics.metric_string_to_torchmetric import MetricStringToTorchMetric +from nemo.collections.common.metrics.metric_string_to_torchmetric import ( + ClassificationMetricsSet, + MetricStringToTorchMetric, + TextMetricsSet, +) from nemo.collections.common.metrics.perplexity import Perplexity diff --git a/nemo/collections/common/metrics/metric_string_to_torchmetric.py b/nemo/collections/common/metrics/metric_string_to_torchmetric.py index b38047b576cc..f91c915309f2 100644 --- a/nemo/collections/common/metrics/metric_string_to_torchmetric.py +++ b/nemo/collections/common/metrics/metric_string_to_torchmetric.py @@ -13,11 +13,13 @@ # limitations under the License. from torchmetrics import Accuracy, AveragePrecision, F1Score, MatthewsCorrCoef, PearsonCorrCoef, SpearmanCorrCoef +from torchmetrics.text import SacreBLEUScore from torchmetrics.text.rouge import ROUGEScore +from torchmetrics.text.wer import WordErrorRate from nemo.collections.common.metrics.classification_accuracy import ExactStringMatchMetric, TokenF1Score -__all__ = ['MetricStringToTorchMetric'] +__all__ = ['MetricStringToTorchMetric', 'TextMetricsSet', 'ClassificationMetricsSet'] # Dictionary that maps a metric string name to its corresponding torchmetric class. @@ -31,4 +33,10 @@ 'matthews_corr_coef': MatthewsCorrCoef, 'exact_string_match': ExactStringMatchMetric, 'rouge': ROUGEScore, + 'wer': WordErrorRate, + 'bleu': SacreBLEUScore, } + +TextMetricsSet = set(['rouge', 'wer', 'bleu']) + +ClassificationMetricsSet = set(['accuracy', 'average_precision', 'f1', 'exact_string_match']) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 66def034400f..24ca6cffe458 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -17,11 +17,11 @@ import os from itertools import combinations from typing import Any, Dict, Iterable, List, Optional, Union - +import numpy as np import pandas as pd from nemo.collections.common.parts.preprocessing import manifest, parsers -from nemo.utils import logging +from nemo.utils import logging, logging_mode class _Collection(collections.UserList): @@ -320,7 +320,13 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): **kwargs: Kwargs to pass to `AudioText` constructor. """ - ids, audio_files, durations, texts, offsets, = ( + ( + ids, + audio_files, + durations, + texts, + offsets, + ) = ( [], [], [], @@ -343,6 +349,19 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): ) +class SpeechLLMAudioTextEntity(object): + def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: + self.id = sid + self.audio_file = audio_file + self.duration = duration + self.context = context + self.answer = answer + self.offset = offset + self.speaker = speaker + self.orig_sr = orig_sr + self.lang = lang + + class ASRVideoText(VideoText): """`VideoText` collector from cv structured json files.""" @@ -356,7 +375,13 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): **kwargs: Kwargs to pass to `VideoText` constructor. """ - ids, video_files, durations, texts, offsets, = ( + ( + ids, + video_files, + durations, + texts, + offsets, + ) = ( [], [], [], @@ -379,10 +404,272 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): ) +class SpeechLLMAudioText(object): + """List of audio-transcript text correspondence with preprocessing. + + All of the audio, duration, context, answer are optional. + If answer is not present, text is treated as the answer. + """ + + def __init__( + self, + ids: List[int], + audio_files: List[str], + durations: List[float], + context_list: List[str], + answers: List[str], + offsets: List[str], + speakers: List[Optional[int]], + orig_sampling_rates: List[Optional[int]], + langs: List[Optional[str]], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + max_num_samples: Optional[int] = None, + ): + """Instantiates audio-context-answer manifest with filters and preprocessing. + + + Args: + ids: List of examples positions. + audio_files: List of audio files. + durations: List of float durations. + context_list: List of raw text transcripts. + answers: List of raw text transcripts. + offsets: List of duration offsets or None. + speakers: List of optional speakers ids. + orig_sampling_rates: List of original sampling rates of audio files. + langs: List of language ids, one for eadh sample, or None. + min_duration: Minimum duration to keep entry with (default: None). + max_duration: Maximum duration to keep entry with (default: None). + max_number: Maximum number of samples to collect. + do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data. + """ + + data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0 + if index_by_file_id: + self.mapping = {} + + for id_, audio_file, duration, offset, context, answer, speaker, orig_sr, lang in zip( + ids, audio_files, durations, offsets, context_list, answers, speakers, orig_sampling_rates, langs + ): + # Duration filters. + if duration is not None: + curr_min_dur = min(duration) if isinstance(duration, list) else duration + curr_max_dur = max(duration) if isinstance(duration, list) else duration + curr_sum_dur = sum(duration) if isinstance(duration, list) else duration + if min_duration is not None and curr_min_dur < min_duration: + duration_filtered += curr_sum_dur + num_filtered += 1 + continue + + if max_duration is not None and curr_max_dur > max_duration: + duration_filtered += curr_sum_dur + num_filtered += 1 + continue + total_duration += curr_sum_dur + + if answer is None: + duration_filtered += curr_sum_dur + num_filtered += 1 + continue + + data.append( + SpeechLLMAudioTextEntity(id_, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) + ) + if index_by_file_id and audio_file is not None: + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + if file_id not in self.mapping: + self.mapping[file_id] = [] + self.mapping[file_id].append(len(data) - 1) + + # Max number of entities filter. + if len(data) == max_number: + break + + if max_num_samples is not None and not index_by_file_id: + if max_num_samples <= len(data): + logging.info(f"Subsampling dataset from {len(data)} to {max_num_samples} samples") + data = data[:max_num_samples] + else: + logging.info(f"Oversampling dataset from {len(data)} to {max_num_samples} samples") + data = data * (max_num_samples // len(data)) + res_num = max_num_samples % len(data) + res_data = [data[idx] for idx in np.random.choice(len(data), res_num, replace=False)] + data.extend(res_data) + elif max_num_samples is not None and index_by_file_id: + logging.warning("Tried to subsample dataset by max_num_samples, but cannot since index_by_file_id is set.") + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + self.data = data + + def __getitem__(self, idx): + if idx < 0 or idx > len(self.data): + raise ValueError(f"index out of range [0,{len(self.data)}), got {idx} instead") + return self.data[idx] + + def __len__(self): + return len(self.data) + + +class SpeechLLMAudioTextCollection(SpeechLLMAudioText): + """`SpeechLLMAudioText` collector from SpeechLLM json files. + + This collector also keeps backward compatibility with SpeechLLMAudioText. + """ + + def __init__( + self, + manifests_files: Union[str, List[str]], + context_file: Optional[Union[List[str], str]] = None, + context_key: str = "context", + answer_key: str = "answer", + *args, + **kwargs, + ): + """Parse lists of audio files, durations and transcripts texts. + + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `AudioText` constructor. + **kwargs: Kwargs to pass to `AudioText` constructor. + """ + self.context_key = context_key + self.answer_key = answer_key + + ( + ids, + audio_files, + durations, + context_list, + answers, + offsets, + ) = ( + [], + [], + [], + [], + [], + [], + ) + speakers, orig_srs, langs = ( + [], + [], + [], + ) + if context_file is not None: + question_file_list = context_file.split(",") if isinstance(context_file, str) else context_file + self.context_list = [] + for filepath in question_file_list: + with open(filepath, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line: + self.context_list.append(line) + logging.info(f"Use random text context from {context_file} for {manifests_files}") + else: + self.context_list = None + + for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item): + ids.append(item['id']) + audio_files.append(item['audio_file']) + durations.append(item['duration']) + context_list.append(item['context']) + answers.append(item['answer']) + offsets.append(item['offset']) + speakers.append(item['speaker']) + orig_srs.append(item['orig_sr']) + langs.append(item['lang']) + super().__init__( + ids, audio_files, durations, context_list, answers, offsets, speakers, orig_srs, langs, *args, **kwargs + ) + + def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: + item = json.loads(line) + + # Audio file + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + elif 'audio_file' not in item: + item['audio_file'] = None + + # If the audio path is a relative path and does not exist, + # try to attach the parent directory of manifest to the audio path. + # Revert to the original path if the new path still doesn't exist. + # Assume that the audio path is like "wavs/xxxxxx.wav". + if item['audio_file'] is not None: + item['audio_file'] = manifest.get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file) + + # Duration. + if 'duration' not in item: + item['duration'] = None + + # Answer. + if self.answer_key in item: + item['answer'] = item.pop(self.answer_key) + elif 'text' in item: + # compatability with ASR manifests that uses 'text' as answer key + item['answer'] = item.pop('text') + elif 'text_filepath' in item: + with open(item.pop('text_filepath'), 'r') as f: + item['answer'] = f.read() + else: + item['answer'] = "na" + + # context. + if self.context_key in item: + item['context'] = item.pop(self.context_key) + elif 'context_filepath' in item: + with open(item.pop('context_filepath'), 'r') as f: + item['context'] = f.read() + elif self.context_list is not None: + context = np.random.choice(self.context_list).strip() + item['context'] = context + elif 'question' in item: + # compatability with old manifests that uses 'question' as context key + logging.warning( + f"Neither `{self.context_key}` is found nor `context_file` is set, but found `question` in item: {item}", + mode=logging_mode.ONCE, + ) + item['context'] = item.pop('question') + else: + # default context if nothing is found + item['context'] = "what does this audio mean" + + item = dict( + audio_file=item['audio_file'], + duration=item['duration'], + context=str(item['context']), + answer=str(item['answer']), + offset=item.get('offset', None), + speaker=item.get('speaker', None), + orig_sr=item.get('orig_sample_rate', None), + lang=item.get('lang', None), + ) + return item + + class SpeechLabel(_Collection): """List of audio-label correspondence with preprocessing.""" - OUTPUT_TYPE = collections.namedtuple(typename='SpeechLabelEntity', field_names='audio_file duration label offset',) + OUTPUT_TYPE = collections.namedtuple( + typename='SpeechLabelEntity', + field_names='audio_file duration label offset', + ) def __init__( self, @@ -532,7 +819,10 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: class FeatureSequenceLabel(_Collection): """List of feature sequence of label correspondence with preprocessing.""" - OUTPUT_TYPE = collections.namedtuple(typename='FeatureSequenceLabelEntity', field_names='feature_file seq_label',) + OUTPUT_TYPE = collections.namedtuple( + typename='FeatureSequenceLabelEntity', + field_names='feature_file seq_label', + ) def __init__( self, @@ -614,9 +904,11 @@ class ASRFeatureSequenceLabel(FeatureSequenceLabel): """`FeatureSequenceLabel` collector from asr structured json files.""" def __init__( - self, manifests_files: Union[str, List[str]], max_number: Optional[int] = None, index_by_file_id: bool = False, + self, + manifests_files: Union[str, List[str]], + max_number: Optional[int] = None, + index_by_file_id: bool = False, ): - """Parse lists of feature files and sequences of labels. Args: @@ -655,7 +947,10 @@ def _parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: f"Manifest file has invalid json line " f"structure: {line} without proper seq_label key." ) - item = dict(feature_file=item['feature_file'], seq_label=item['seq_label'],) + item = dict( + feature_file=item['feature_file'], + seq_label=item['seq_label'], + ) return item @@ -759,7 +1054,8 @@ def __init__( data.sort(key=lambda entity: entity.duration) logging.info( - "Filtered duration for loading collection is %f.", duration_filtered, + "Filtered duration for loading collection is %f.", + duration_filtered, ) logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") @@ -937,8 +1233,7 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: class Audio(_Collection): - """Prepare a list of all audio items, filtered by duration. - """ + """Prepare a list of all audio items, filtered by duration.""" OUTPUT_TYPE = collections.namedtuple(typename='Audio', field_names='audio_files duration offset text') @@ -999,11 +1294,14 @@ def __init__( class AudioCollection(Audio): - """List of audio files from a manifest file. - """ + """List of audio files from a manifest file.""" def __init__( - self, manifest_files: Union[str, List[str]], audio_to_manifest_key: Dict[str, str], *args, **kwargs, + self, + manifest_files: Union[str, List[str]], + audio_to_manifest_key: Dict[str, str], + *args, + **kwargs, ): """Instantiates a list of audio files loaded from a manifest file. @@ -1045,6 +1343,7 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: Returns: Dictionary with audio_files, duration, and offset. """ + # Local utility function def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]): """Get item[key] if key is string, or a list @@ -1117,7 +1416,10 @@ def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]): class FeatureLabel(_Collection): """List of feature sequence and their label correspondence with preprocessing.""" - OUTPUT_TYPE = collections.namedtuple(typename='FeatureLabelEntity', field_names='feature_file label duration',) + OUTPUT_TYPE = collections.namedtuple( + typename='FeatureLabelEntity', + field_names='feature_file label duration', + ) def __init__( self, @@ -1194,7 +1496,6 @@ def __init__( *args, **kwargs, ): - """Parse lists of feature files and sequences of labels. Args: @@ -1383,7 +1684,14 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): **kwargs: Kwargs to pass to `AudioText` constructor. """ - ids, feature_files, rttm_files, durations, texts, offsets, = ( + ( + ids, + feature_files, + rttm_files, + durations, + texts, + offsets, + ) = ( [], [], [], diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index b686322c0882..aed05673f6fa 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -28,7 +28,7 @@ class SentencePieceTokenizer(TokenizerSpec): """ Sentencepiecetokenizer https://github.com/google/sentencepiece. - + Args: model_path: path to sentence piece tokenizer model. To create the model use create_spt_model() special_tokens: either list of special tokens or dictionary of token name to token value @@ -87,7 +87,7 @@ def text_to_tokens(self, text): return self.tokenizer.encode_as_pieces(text) - def text_to_ids(self, text): + def text_to_ids(self, text, sample_alpha=None): if self.legacy: ids = [] idx = 0 @@ -115,7 +115,10 @@ def text_to_ids(self, text): ids.extend(self.tokenizer.encode_as_ids(text[idx:])) return ids - return self.tokenizer.encode_as_ids(text) + if sample_alpha is not None: + return self.tokenizer.encode_as_ids(text, enable_sampling=True, alpha=sample_alpha, nbest_size=-1) + else: + return self.tokenizer.encode_as_ids(text) def tokens_to_text(self, tokens): if isinstance(tokens, np.ndarray): diff --git a/nemo/collections/multimodal/speech_llm/__init__.py b/nemo/collections/multimodal/speech_llm/__init__.py new file mode 100644 index 000000000000..f0c19a3eebb9 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.multimodal.speech_llm import models, modules diff --git a/nemo/collections/multimodal/speech_llm/data/__init__.py b/nemo/collections/multimodal/speech_llm/data/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py b/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py new file mode 100644 index 000000000000..7d0ee6afbfa2 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py @@ -0,0 +1,1327 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import io +import os +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +import webdataset as wds +from omegaconf import DictConfig, ListConfig, open_dict + +from nemo.collections.asr.data.audio_to_text import ( + VALID_FILE_FORMATS, + cache_datastore_manifests, + expand_sharded_filepaths, + shard_manifests_if_needed, +) +from nemo.collections.asr.data.audio_to_text_dataset import ConcatDataset, convert_to_config_list, get_chain_dataset +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common.parts.preprocessing import collections +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import ( + ceil_to_nearest, + get_num_samples_from_files, + maybe_cast_to_list, +) +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.core.classes import Dataset, IterableDataset +from nemo.utils import logging, logging_mode +from nemo.utils.distributed import webdataset_split_by_workers + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False + +__all__ = [ + 'AudioTextDataset', + 'TarredAudioTextDataset', + 'get_tarred_audio_text_dataset_from_config', + 'get_audio_text_dataset_from_config', +] + + +def _audio_collate_fn(audio_signals, audio_lengths): + """collate batch of audio sig, audio len, tokens, tokens len + Args: + audio_signals: List[Tensor] + audio_lengths: List[Tensor] + """ + + max_audio_len = 0 + has_audio = audio_lengths[0] is not None + if has_audio: + max_audio_len = max(audio_lengths).item() + + audio_signals_padded = [] + for sig, sig_len in zip(audio_signals, audio_lengths): + if has_audio: + sig_len = sig_len.item() + if sig_len < max_audio_len: + pad = (0, max_audio_len - sig_len) + sig = torch.nn.functional.pad(sig, pad) + audio_signals_padded.append(sig) + + if has_audio: + audio_signals_padded = torch.stack(audio_signals_padded) + audio_lengths = torch.stack(audio_lengths) + else: + audio_signals_padded, audio_lengths = None, None + + return audio_signals_padded, audio_lengths + + +def _build_loss_mask(processed_example: Dict, answer_only_loss: bool = True): + """Pad input_ids in batch to max batch length while building loss mask""" + # function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + input_ids = processed_example['input_ids'] + answer_start_idx = processed_example['answer_start_idx'] + if answer_only_loss: + loss_mask = [float(idx >= answer_start_idx) for idx in range(len(input_ids))] + else: + loss_mask = [1.0] * len(input_ids) + + return loss_mask + + +def _collate_item(item: Union[torch.Tensor, np.ndarray, List], max_length: int, pad_id: int = 0): + # function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + item = maybe_cast_to_list(item) + # max_length = max([len(x) for x in item]) if item else 0 + # here [0] should be tokenizer.pad_id + item = [x + [pad_id] * (max_length - len(x)) for x in item] + return item + + +def _speechllm_audio_text_collate_fn( + batch: Dict, + tokens_to_generate: int, + pad_to_max_length: bool, + max_seq_length: int, + text_pad_id: int, +): + sample_ids = [x["idx"] for x in batch] + sample_ids = torch.tensor(sample_ids, dtype=torch.int32) + + audio_signal = [x["audio_signal"] for x in batch] + audio_lengths = [x["audio_length"] for x in batch] + audio_signal, audio_lengths = _audio_collate_fn(audio_signal, audio_lengths) + + input_ids = [item['input_ids'][:-1] for item in batch] + labels = [item['input_ids'][1:] for item in batch] + contexts = [item['context_ids'] for item in batch] + context_lengths = torch.LongTensor([item['context_length'] for item in batch]) + answers = [item['answer_ids'] for item in batch] + + loss_mask = [_build_loss_mask(item)[1:] for item in batch] + + max_length = max([len(x) for x in input_ids]) + tokens_to_generate + # increase max length to nearest multiple of 4 or 8 + if pad_to_max_length: + max_length = max_seq_length + else: + max_length = min(max_seq_length, ceil_to_nearest(max_length, 8)) + assert max_length <= max_seq_length + + position_ids = [list(range(max_length)) for _ in batch] + position_ids = torch.LongTensor(position_ids) + input_ids = torch.LongTensor(_collate_item(input_ids, max_length=max_length, pad_id=text_pad_id)) + input_length = torch.LongTensor([len(x) for x in input_ids]) + labels = torch.LongTensor(_collate_item(labels, max_length=max_length, pad_id=text_pad_id)) + loss_mask = torch.LongTensor(_collate_item(loss_mask, max_length=max_length, pad_id=0)) + contexts = torch.LongTensor(_collate_item(contexts, max_length=max_length, pad_id=text_pad_id)) + answers = torch.LongTensor(_collate_item(answers, max_length=max_length, pad_id=text_pad_id)) + + batch = { + 'sample_ids': sample_ids, + 'audio_signal': audio_signal, + 'audio_signal_length': audio_lengths, + 'tokens': input_ids, + 'tokens_length': input_length, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'contexts': contexts, + 'context_lengths': context_lengths, + 'answers': answers, + 'max_length': torch.LongTensor(max_length), + 'metadata': [x['metadata'] for x in batch], + } + + return batch + + +def _speechllm_multi_audio_text_collate_fn( + batch: Dict, + tokens_to_generate: int, + pad_to_max_length: bool, + max_seq_length: int, + text_pad_id: int, +): + """Collate function for multi audio case.""" + context_start_idx = [item['context_start_idx'] for item in batch] + + audio_signals = [x["audio_signal"] for x in batch] + audio_lengths = [x["audio_length"] for x in batch] + num_audios = [len(x) for x in audio_signals] + + # put all audios from all samples in one batch + audio_signals_merged = [item for audio_list in audio_signals for item in audio_list] + audio_lengths_merged = [item for length_list in audio_lengths for item in length_list] + audio_signals_merged, audio_lengths_merged = _audio_collate_fn(audio_signals_merged, audio_lengths_merged) + + for i in range(len(batch)): + # create dummy audio_signal and audio_length for _speechllm_audio_text_collate_fn() + batch[i]["audio_signal"] = audio_signals[i][0] + batch[i]["audio_length"] = audio_lengths[i][0] + + batch = _speechllm_audio_text_collate_fn(batch, tokens_to_generate, pad_to_max_length, max_seq_length, text_pad_id) + + # add multi audio specific fields + batch['context_start_idx'] = list(context_start_idx) + batch['num_audios'] = torch.LongTensor(num_audios) + batch['audio_signal'] = audio_signals_merged + batch['audio_signal_length'] = audio_lengths_merged + + return batch + + +class TextProcessing(object): + """ + Text processing pipeline for AudioTextDataset and TarredAudioTextDataset. + This class is adapted from the one used in nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py + """ + + def __init__( + self, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + add_sep: bool = False, + sep_id: Optional[int] = None, + seed: int = 1234, + separate_prompt_and_response_with_newline: bool = False, + answer_only_loss: bool = True, + truncation_field: str = "answer", + pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. + prompt_template: str = None, + virtual_tokens: int = 0, + tokens_to_generate: int = 0, + context_key: str = 'context', + answer_key: str = 'answer', + end_string: Optional[str] = None, + sample_alpha: Optional[float] = None, + audio_locator: Optional[str] = None, + ): + self.context_key = context_key + self.answer_key = answer_key + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.seed = seed + self.separate_prompt_and_response_with_newline = separate_prompt_and_response_with_newline + self.answer_only_loss = answer_only_loss + self.truncation_field = truncation_field + self.pad_to_max_length = pad_to_max_length + self.prompt_template = prompt_template + self.virtual_tokens = virtual_tokens + self.tokens_to_generate = tokens_to_generate + self.add_bos = add_bos + self.add_eos = add_eos + self.add_sep = add_sep + self.end_string = end_string + self.sample_alpha = sample_alpha + self.audio_locator = audio_locator + + if add_bos and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: + self.bos_id = tokenizer.bos_id + else: + self.bos_id = None + + if add_eos and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0: + self.eos_id = tokenizer.eos_id + else: + self.eos_id = None + + if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0: + self.pad_id = tokenizer.pad_id + else: + self.pad_id = self.eos_id if self.eos_id is not None else 0 + + self.sep_id = sep_id if add_sep else None + + if self.prompt_template is not None: + # When providing things like newlines in the prompt template via the CLI, they are escaped. This line unescapes them. + self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape') + assert self.truncation_field in ["answer", "context"] + + def _process_example(self, context: str, output: str): + """ + Create an example by concatenating text and answer. + Truncation is carried out when needed, but it is performed only on the prompt side. + BOS, EOS, and SEP, are added if specified. + + function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + """ + if self.prompt_template is not None: + if self.context_key not in self.prompt_template or self.answer_key not in self.prompt_template: + if "input" in self.prompt_template and "output" in self.prompt_template: + logging.warning( + f"Using 'input' and 'output' as context and answer keys, since given ones ({self.context_key}, {self.answer_key}) are not found in the prompt template: {self.prompt_template}.", + mode=logging_mode.ONCE, + ) + self.context_key = "input" + self.answer_key = "output" + assert f'{{{self.context_key}}}' in self.prompt_template + assert f'{{{self.answer_key}}}' in self.prompt_template + # Make sure that '{output}' always occurs at the end of the prompt template string + assert self.prompt_template.index(f'{{{self.answer_key}}}') == len(self.prompt_template) - len( + f'{{{self.answer_key}}}' + ) + # Get the context by replacing only the input + original_context = context + context = ( + self.prompt_template.replace(f'{{{self.context_key}}}', context) + .replace(f'{{{self.answer_key}}}', '') + .strip(' ') + ) + # Replace the input and output placeholders with the actual input and output + text = self.prompt_template.replace(f'{{{self.context_key}}}', original_context).replace( + f'{{{self.answer_key}}}', output + ) + + elif self.separate_prompt_and_response_with_newline: + text = context + '\n' + output + else: + text = context + ' ' + output + + if self.virtual_tokens: + # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context + # these pad/eos tokens are placeholders for virtual tokens + pre_pad = [self.tokenizer.eos_id] * self.virtual_tokens + else: + pre_pad = [] + answer_text = text[len(context) :] + answer_ids = pre_pad + self.tokenizer.text_to_ids(answer_text, self.sample_alpha) + if self.end_string: + answer_ids += self.tokenizer.text_to_ids(self.end_string) + + if self.audio_locator is None: + # signle audio case + context_ids = self.tokenizer.text_to_ids(context) + context_start_idx = [0] + else: + # multiple audio case + context_ids = [] + context_start_idx = [] + for context_seg in context.split(self.audio_locator): + context_start_idx.append(len(context_ids)) + context_ids.extend(self.tokenizer.text_to_ids(context_seg)) + context_ids = pre_pad + context_ids + context_start_idx = [x + len(pre_pad) for x in context_start_idx] + + # for the long context cases, collate_fn includes self.tokens_to_generate for padding + total_ids = len(context_ids) + max(len(answer_ids), self.tokens_to_generate) + if self.add_bos: + total_ids += 1 + if self.add_sep: + total_ids += 1 + # Only training need to consider eos token + if self.add_eos and self.tokens_to_generate == 0: + total_ids += 1 + + # If the total number of token is greater than the max, we will try to truncate the answer + if total_ids > self.max_seq_length: + truncation_length = total_ids - self.max_seq_length + if self.truncation_field == "answer": + answer_ids = answer_ids[: -min(truncation_length, len(answer_ids))] + elif self.truncation_field == "context": + context_ids = context_ids[: -min(truncation_length, len(context_ids))] + + input_ids = context_ids + answer_start_idx = len(input_ids) + + # Adds bos token in the start + if self.add_bos: + context_ids = [self.tokenizer.bos_id] + context_ids + input_ids = [self.tokenizer.bos_id] + input_ids + answer_start_idx += 1 + + # Adds sep token between text/prompt and answer + if self.add_sep: + context_ids = context_ids + [self.sep_id] + input_ids = input_ids + [self.sep_id] + answer_start_idx += 1 + + input_ids = input_ids + answer_ids + + # Only training need to consider eos token + if self.add_eos and self.tokens_to_generate == 0: + input_ids = input_ids + [self.tokenizer.eos_id] + + if len(input_ids) > self.max_seq_length: + logging.warning(f'Input ids length {len(input_ids)} exceed max sequence length {self.max_seq_length}') + input_ids = input_ids[: self.max_seq_length] + + processed_example = { + 'input_ids': input_ids, + 'answer_start_idx': answer_start_idx, + 'context_ids': context_ids, + 'context_length': len(context_ids), + 'answer_ids': answer_ids, + 'context_start_idx': context_start_idx, + } + + return processed_example + + +class AudioTextDataset(TextProcessing, Dataset): + """ + Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds). + Each new line is a different sample. Example below: + {"audio_filepath": "1.wav", "duration": 1.12, "question": "what is the capital of France?", "answer": "Paris"} + {"audio_filepath": "2.wav", "duration": 2.15, "question": "what is the capital of Italy?", "answer": "Rome"} + Args: + manifest_filepath: Path to manifest json as described above. Can be comma-separated paths. + tokenizer: text tokenizer object + sample_rate (int): Sample rate to resample loaded audio to + int_values (bool): If true, load samples as 32-bit integers. Defauts to False. + augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor object used to augment loaded + audio + max_duration: If audio exceeds this length, do not include in dataset + min_duration: If audio is less than this length, do not include in dataset + max_utts: Limit number of utterances + trim: whether or not to trim silence. Defaults to False + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + --------- NLP SPECIFIC ARGS ------------- + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + add_sep (bool): Whether to add a separation token to each data example (goes between prompt and answer) + tokens_to_generate (int): (inference only) Number of tokens to generate during inference + seed: Random seed for data shuffling. + max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. + seed: int = 1234, + context_key: Key to use for the context in your JSONL file + answer_key: Key to use for the label in your JSONL file + separate_prompt_and_response_with_newline: Adds a newline between prompt and response. + answer_only_loss: If True, will compute the loss only on the answer part of the input. If False, will compute the loss on the entire input. + truncation_field: Field to use for truncation. (Options: "answer", "context"). Field to be used for truncation if the combined length exceeds the max sequence length. + pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. + prompt_template: Prompt template to inject via an fstring. Formatted like Q: {input}\n\nA: {output} + end_string: Optional[str] = None, if not None, add this string to the end of the answer. + --------------- additional args for misc purposes ---------------- + context_file: Optional[Union[List[str], str]] = None, if provided, will use this file to load random questions from, if question is not in manifest. + sample_alpha: Optional[float] = None, for SPE subword sampling + audio_locator: Optional[str] = None, a special string to split the context into multiple audio segments. + """ + + def __init__( + self, + manifest_filepath: str, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + sample_rate: int, + int_values: bool = False, + augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None, + max_duration: Optional[int] = None, + min_duration: Optional[int] = None, + max_utts: int = 0, + trim: bool = False, + channel_selector: Optional[ChannelSelectorType] = None, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + add_sep: bool = False, + sep_id: Optional[int] = None, + max_num_samples: Optional[int] = None, + seed: int = 1234, + separate_prompt_and_response_with_newline: bool = False, + answer_only_loss: bool = True, + truncation_field: str = "answer", + pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. + prompt_template: str = None, + virtual_tokens: int = 0, + tokens_to_generate: int = 0, + index_by_file_id: bool = False, + context_key: str = 'context', + answer_key: str = 'answer', + end_string: Optional[str] = None, + context_file: Optional[Union[List[str], str]] = None, + sample_alpha: Optional[float] = None, + audio_locator: Optional[str] = None, + ): + super().__init__( + tokenizer=tokenizer, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + add_sep=add_sep, + sep_id=sep_id, + seed=seed, + separate_prompt_and_response_with_newline=separate_prompt_and_response_with_newline, + answer_only_loss=answer_only_loss, + truncation_field=truncation_field, + pad_to_max_length=pad_to_max_length, + prompt_template=prompt_template, + virtual_tokens=virtual_tokens, + tokens_to_generate=tokens_to_generate, + context_key=context_key, + answer_key=answer_key, + end_string=end_string, + sample_alpha=sample_alpha, + audio_locator=audio_locator, + ) + + if isinstance(manifest_filepath, str): + manifest_filepath = manifest_filepath.split(",") + + # If necessary, cache manifests and audio from object store + cache_datastore_manifests(manifest_filepaths=manifest_filepath, cache_audio=True) + + self.collection = collections.SpeechLLMAudioTextCollection( + manifests_files=manifest_filepath, + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + index_by_file_id=index_by_file_id, + max_num_samples=max_num_samples, + context_file=context_file, + context_key=context_key, + answer_key=answer_key, + ) + + self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) + self.trim = trim + self.channel_selector = channel_selector + + def get_manifest_sample(self, sample_id): + return self.collection[sample_id] + + def __getitem__(self, index): + output = {"idx": index} + sample = self.collection[index] + offset = sample.offset + + if offset is None: + offset = 0 + + if sample.audio_file is not None: + features = self.featurizer.process( + sample.audio_file, + offset=offset, + duration=sample.duration, + trim=self.trim, + orig_sr=sample.orig_sr, + channel_selector=self.channel_selector, + ) + f, fl = features, torch.tensor(features.shape[0]).long() + output["audio_signal"] = f + output["audio_length"] = fl + else: + # dummy features + output["audio_signal"] = torch.zeros([80]) + # accomodates normalize_batch + output["audio_length"] = torch.tensor(80) + + text_data = self._process_example(context=sample.context, output=sample.answer) + + output.update(text_data) + output['metadata'] = { + 'audio_filepath': sample.audio_file, + 'offset': offset, + 'duration': sample.duration, + } + return output + + def __len__(self): + return len(self.collection) + + def _collate_fn(self, batch): + return _speechllm_audio_text_collate_fn( + batch=batch, + tokens_to_generate=self.tokens_to_generate, + pad_to_max_length=self.pad_to_max_length, + max_seq_length=self.max_seq_length, + text_pad_id=self.pad_id, + ) + + def collate_fn(self, batch): + # override collate_fn to skip type checking + return self._collate_fn(batch) + + +class MultiAudioTextDataset(AudioTextDataset): + """ + Dataset for having multi audios per sample, for example in few-shot in-context learning. + To use this dataset, you need to specify the `audio_locator` field in the dataset config, + and use that to specify the locations of the audio files in your manifest. In this case, + the `audio_filepath` field in the manifest is a list of audio filepaths, and the `duration` + field is a list of durations, one for each audio file. The `offset` field is optional, and + if not specified, it is assumed to be 0.0. The `offset` field is also a list of offsets if specified. + + Example manifest item for audio_locator='|audio|': + { + "audio_filepath": ["1.wav","2.wav","3.wav"], + "duration": [1.05,1.05,2.0], + "answer": "this was her dream as nearly as she could recall it", + "question": "Following are examples of speech audios and their transcriptions. + Example 1: audio is |audio|, transcription is 'I have a dream'. + Example 2: audio is |audio|, transcription is ' I don't have a dream'. + Given the following audio |audio|, transcribe the audio into words." + } + """ + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def _collate_fn(self, batch): + return _speechllm_multi_audio_text_collate_fn( + batch=batch, + tokens_to_generate=self.tokens_to_generate, + pad_to_max_length=self.pad_to_max_length, + max_seq_length=self.max_seq_length, + text_pad_id=self.pad_id, + ) + + def __getitem__(self, index): + output = {"idx": index} + sample = self.collection[index] + offsets = sample.offset if sample.offset else 0.0 + durations = sample.duration if sample.duration else 0.0 + num_audios = 0 + output["audio_signal"] = [] + output["audio_length"] = [] + if sample.audio_file is not None: + audio_list = sample.audio_file + if isinstance(sample.audio_file, str): + audio_list = [sample.audio_file] + if not isinstance(audio_list, list): + raise ValueError( + f"The field `audio_file` must be either a str or a list of str, but got type {type(sample.audio_file)} instead" + ) + + num_audios = len(audio_list) + if isinstance(durations, list) and len(durations) != num_audios: + raise ValueError( + f"The number of durations ({len(durations)}) must match the number of audio clips ({num_audios})" + ) + if isinstance(offsets, list) and len(offsets) != num_audios: + raise ValueError( + f"The number of offsets ({len(offsets)}) must match the number of audio clips ({num_audios})" + ) + + for i, audio_file in enumerate(audio_list): + duration = durations[i] if isinstance(durations, list) else 0 + offset = offsets[i] if isinstance(offsets, list) else 0 + features = self.featurizer.process( + audio_file, + offset=offset, + duration=duration, + trim=self.trim, + orig_sr=sample.orig_sr, + channel_selector=self.channel_selector, + ) + f, fl = features, torch.tensor(features.shape[0]).long() + output["audio_signal"].append(f) + output["audio_length"].append(fl) + else: + # dummy features + output["audio_signal"] = [torch.zeros([8])] + # accomodates normalize_batch + output["audio_length"] = [torch.tensor(8)] + + text_data = self._process_example(context=sample.context, output=sample.answer) + + if isinstance(output["audio_signal"], list) and len(output["audio_signal"]) + 1 != len( + text_data['context_start_idx'] + ): + raise ValueError( + f"The number of text segments ({len(text_data['context_start_idx'])}) must be one more than number of audios ({len(output['audio_signal'])})" + ) + + output.update(text_data) + output['metadata'] = { + 'audio_filepath': sample.audio_file, + 'offset': offsets, + 'duration': sample.duration, + } + return output + + +class TarredAudioFilter: + def __init__(self, collection, iterator): + self.iterator = iterator + self.collection = collection + + def __iter__(self): + return self + + def __next__(self): + while True: + audio_bytes, audio_filename = next(self.iterator) + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + if file_id in self.collection.mapping: + return audio_bytes, audio_filename + + +class TarredAudioLoopOffsets: + def __init__(self, collection, iterator): + self.iterator = iterator + self.collection = collection + self.current_fn = None + self.current_bytes = None + self.offset_id = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.current_fn is None: + self.current_bytes, self.current_fn = next(self.iterator) + self.offset_id = 0 + else: + offset_list = self.collection.mapping[self.current_fn] + if len(offset_list) == self.offset_id + 1: + self.current_bytes, self.current_fn = next(self.iterator) + self.offset_id = 0 + else: + self.offset_id += 1 + + return self.current_bytes, self.current_fn, self.offset_id + + +class TarredAudioTextDataset(TextProcessing, IterableDataset): + """ + A similar Dataset to the AudioTextDataset, but which loads tarred audio files. + + Accepts a single comma-separated JSON manifest file (in the same style as for the AudioTextDataset), + as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should + contain the information for one audio file, including at least the transcript and name of the audio + file within the tarball. + + Valid formats for the audio_tar_filepaths argument include: + (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or + (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...]. + + Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference. + This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements. + Supported opening braces - { <=> (, [, < and the special tag _OP_. + Supported closing braces - } <=> ), ], > and the special tag _CL_. + For SLURM based tasks, we suggest the use of the special tags for ease of use. + + See the WebDataset documentation for more information about accepted data and input formats. + + If using multiple workers the number of shards should be divisible by world_size to ensure an + even split among workers. If it is not divisible, logging will give a warning but training will proceed. + In addition, if using mutiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering + is applied. We currently do not check for this, but your program may hang if the shards are uneven! + + Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest + after filtering. An incorrect manifest length may lead to some DataLoader issues down the line. + + Args: + audio_tar_filepaths: Either a list of audio tarball filepaths, or a + string (can be brace-expandable). + manifest_filepath (str): Path to the manifest. + parser (callable): A callable which is used to pre-process the text output. + sample_rate (int): Sample rate to resample loaded audio to + int_values (bool): If true, load samples as 32-bit integers. Defauts to False. + augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor + object used to augment loaded audio + shuffle_n (int): How many samples to look ahead and load to be shuffled. + See WebDataset documentation for more details. + Defaults to 0. + min_duration (float): Dataset parameter. + All training files which have a duration less than min_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to 0.1. + max_duration (float): Dataset parameter. + All training files which have a duration more than max_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to None. + blank_index (int): Blank character index, defaults to -1. + unk_index (int): Unknown character index, defaults to -1. + normalize (bool): Dataset parameter. + Whether to use automatic text cleaning. + It is highly recommended to manually clean text for best results. + Defaults to True. + trim (bool): Whether to use trim silence from beginning and end + of audio signal using librosa.effects.trim(). + Defaults to False. + bos_id (id): Dataset parameter. + Beginning of string symbol id used for seq2seq models. + Defaults to None. + eos_id (id): Dataset parameter. + End of string symbol id used for seq2seq models. + Defaults to None. + pad_id (id): Token used to pad when collating samples in batches. + If this is None, pads using 0s. + Defaults to None. + shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp. + - `scatter`: The default shard strategy applied by WebDataset, where each node gets + a unique set of shards, which are permanently pre-allocated and never changed at runtime. + - `replicate`: Optional shard strategy, where each node gets all of the set of shards + available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. + The benefit of replication is that it allows each node to sample data points from the entire + dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. + + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. + shard_manifests (bool): Whether or not to try / shard manifests. Defaults to False. + global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. + world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. + --------- NLP SPECIFIC ARGS ------------- + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + add_sep (bool): Whether to add a separation token to each data example (goes between prompt and answer) + tokens_to_generate (int): (inference only) Number of tokens to generate during inference + seed: Random seed for data shuffling. + seed: int = 1234, + context_key: Key to use for the context in your JSONL file + answer_key: Key to use for the label in your JSONL file + separate_prompt_and_response_with_newline: Adds a newline between prompt and response. + answer_only_loss: If True, will compute the loss only on the answer part of the input. If False, will compute the loss on the entire input. + truncation_field: Field to use for truncation. (Options: "answer", "context"). Field to be used for truncation if the combined length exceeds the max sequence length. + pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. + prompt_template: Prompt template to inject via an fstring. Formatted like Q: {input}\n\nA: {output} + end_string: Optional[str] = None, if not None, add this string to the end of the answer. + --------------- additional args for misc purposes ---------------- + context_file: Optional[Union[List[str], str]] = None, if provided, will use this file to load random questions from, if question is not in manifest. + sample_alpha: Optional[float] = None, for SPE subword sampling + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + sample_rate: int, + int_values: bool = False, + augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None, + shuffle_n: int = 0, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + trim: bool = False, + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + add_sep: bool = False, + sep_id: int = None, + seed: int = 1234, + separate_prompt_and_response_with_newline: bool = False, + answer_only_loss: bool = True, + truncation_field: str = "answer", # choices=["answer", "context"] + pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. + prompt_template: str = None, + virtual_tokens: int = 0, + tokens_to_generate: int = 0, + context_key: str = 'context', + answer_key: str = 'answer', + end_string: Optional[str] = None, + context_file: Optional[Union[List[str], str]] = None, + sample_alpha: Optional[float] = None, + ): + super().__init__( + tokenizer=tokenizer, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + add_sep=add_sep, + sep_id=sep_id, + seed=seed, + separate_prompt_and_response_with_newline=separate_prompt_and_response_with_newline, + answer_only_loss=answer_only_loss, + truncation_field=truncation_field, + pad_to_max_length=pad_to_max_length, + prompt_template=prompt_template, + virtual_tokens=virtual_tokens, + tokens_to_generate=tokens_to_generate, + context_key=context_key, + answer_key=answer_key, + end_string=end_string, + sample_alpha=sample_alpha, + ) + self.is_megatron_iterable = True + self.shard_manifests = shard_manifests + + # Shard manifests if necessary and possible and then expand the paths + manifest_filepath = shard_manifests_if_needed( + shard_manifests=shard_manifests, + shard_strategy=shard_strategy, + manifest_filepaths=manifest_filepath, + world_size=world_size, + global_rank=global_rank, + ) + + # If necessary, cache manifests from object store + cache_datastore_manifests(manifest_filepaths=manifest_filepath) + + self.collection = collections.SpeechLLMAudioTextCollection( + manifests_files=manifest_filepath, + min_duration=min_duration, + max_duration=max_duration, + index_by_file_id=True, + context_file=context_file, + context_key=context_key, + answer_key=answer_key, + ) + + self.len = self._compute_len() + + self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) + self.trim = trim + + audio_tar_filepaths = expand_sharded_filepaths( + sharded_filepaths=audio_tar_filepaths, + shard_strategy=shard_strategy, + world_size=world_size, + global_rank=global_rank, + ) + + # Put together WebDataset + self._dataset = wds.WebDataset(urls=audio_tar_filepaths, nodesplitter=None) + + if shuffle_n == 0: + logging.info("WebDataset will not shuffle files within the tar files.") + + # Put together WebDataset pipeline + self._dataset = wds.DataPipeline( + wds.SimpleShardList(urls=audio_tar_filepaths), + webdataset_split_by_workers, + wds.shuffle(shuffle_n), + wds.tarfile_to_samples(), + wds.rename(audio=VALID_FILE_FORMATS, key='__key__'), + wds.to_tuple('audio', 'key'), + self._filter, + self._loop_offsets, + wds.map(self._build_sample), + ) + + def _filter(self, iterator): + """This function is used to remove samples that have been filtered out by ASRAudioText already. + Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample + that was filtered out (e.g. for duration). + Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard, + which may make your code hang as one process will finish before the other. + """ + return TarredAudioFilter(self.collection, iterator) + + def _loop_offsets(self, iterator): + """This function is used to iterate through utterances with different offsets for each file.""" + return TarredAudioLoopOffsets(self.collection, iterator) + + def _collate_fn(self, batch): + return _speechllm_audio_text_collate_fn( + batch=batch, + tokens_to_generate=self.tokens_to_generate, + pad_to_max_length=self.pad_to_max_length, + max_seq_length=self.max_seq_length, + text_pad_id=self.pad_id, + ) + + def collate_fn(self, batch): + # override collate_fn to skip type checking + return self._collate_fn(batch) + + def _build_sample(self, tup): + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" + audio_bytes, audio_filename, offset_id = tup + + if audio_filename is not None: + # Grab manifest entry from self.manifest_preprocessor.collection + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + manifest_idx = self.collection.mapping[file_id][offset_id] + manifest_entry = self.collection[manifest_idx] + + # init output dict + output = {"idx": manifest_idx} + + offset = manifest_entry.offset + if offset is None: + offset = 0 + # Convert audio bytes to IO stream for processing (for SoundFile to read) + audio_filestream = io.BytesIO(audio_bytes) + features = self.featurizer.process( + audio_filestream, + offset=offset, + duration=manifest_entry.duration, + trim=self.trim, + orig_sr=manifest_entry.orig_sr, + ) + audio_filestream.close() + + # Audio features + output["audio_signal"] = features + output["audio_length"] = torch.tensor(features.shape[0]).long() + else: + # dummy features + output["audio_signal"] = torch.zeros([80]) + # accomodates normalize_batch + output["audio_length"] = torch.tensor(80) + + # Text features + text_data = self._process_example(context=manifest_entry.context, output=manifest_entry.answer) + + output.update(text_data) + + output['metadata'] = { + 'audio_filepath': audio_filename, + 'offset': offset, + 'duration': manifest_entry.duration, + } + return output + + def get_manifest_sample(self, sample_id): + return self.collection[sample_id] + + def __iter__(self): + return self._dataset.__iter__() + + def _compute_len(self): + # TODO: need to figure out why here needs to be divided by world_size, while in ASR we don't need to. + if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized(): + my_len = torch.tensor(len(self.collection), dtype=torch.int32).cuda() + torch.distributed.all_reduce(my_len) + my_len = my_len.int() // parallel_state.get_data_parallel_world_size() + logging.info(f'Sharded manifests: Total length: {my_len}') + else: + my_len = len(self.collection) // parallel_state.get_data_parallel_world_size() + + return my_len + + def __len__(self): + return self.len + + +def get_tarred_audio_text_dataset( + config, + tokenizer, + augmentor, + global_rank=0, + world_size=1, + shuffle_n=0, + sep_id=None, + answer_only_loss=True, + virtual_tokens=0, +): + tarred_audio_filepaths = config['tarred_audio_filepaths'] + manifest_filepaths = config['manifest_filepath'] + datasets = [] + tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths) + manifest_filepaths = convert_to_config_list(manifest_filepaths) + + bucketing_weights = config.get('bucketing_weights', None) # For upsampling buckets + if bucketing_weights: + for idx, weight in enumerate(bucketing_weights): + if not isinstance(weight, int) or weight <= 0: + raise ValueError(f"bucket weights must be positive integers") + + if len(manifest_filepaths) != len(tarred_audio_filepaths): + raise ValueError( + f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets." + ) + + if 'labels' not in config: + logging.warning(f"dataset does not have explicitly defined labels") + + if 'max_utts' in config: + raise ValueError('"max_utts" parameter is not supported for tarred datasets') + + for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate( + zip(tarred_audio_filepaths, manifest_filepaths) + ): + if len(tarred_audio_filepath) == 1: + tarred_audio_filepath = tarred_audio_filepath[0] + if len(manifest_filepath) == 1: + manifest_filepath = manifest_filepath[0] + + dataset = TarredAudioTextDataset( + audio_tar_filepaths=tarred_audio_filepath, + manifest_filepath=manifest_filepath, + tokenizer=tokenizer, + sample_rate=config['sample_rate'], + int_values=config.get('int_values', False), + augmentor=augmentor, + shuffle_n=shuffle_n, + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + trim=config.get('trim_silence', False), + shard_strategy=config.get('tarred_shard_strategy', 'scatter'), + shard_manifests=config.get('shard_manifests', False), + global_rank=global_rank, + world_size=world_size, + max_seq_length=config.max_seq_length, + min_seq_length=config.min_seq_length, + add_bos=config.get('add_bos', False), + add_eos=config.get('add_eos', True), + add_sep=config.get('add_sep', False), + sep_id=sep_id, + separate_prompt_and_response_with_newline=config.get('separate_prompt_and_response_with_newline', True), + answer_only_loss=answer_only_loss, + truncation_field=config.get('truncation_field', 'context'), + pad_to_max_length=False, + prompt_template=config.get('prompt_template', None), + virtual_tokens=virtual_tokens, + tokens_to_generate=config.get( + 'tokens_to_generate', 0 + ), # used at inference time to allocate tensor positions for tokens that will be generated by inf procedure. + context_key=config.get('context_key', 'context'), + answer_key=config.get('answer_key', 'answer'), + end_string=config.get('end_string', None), + sample_alpha=config.get('sample_alpha', None), + context_file=config.get('context_file', None), + ) + + if bucketing_weights: + [datasets.append(dataset) for _ in range(bucketing_weights[dataset_idx])] + else: + datasets.append(dataset) + + with open_dict(config): # patch for bucketing tarred datasets + config['batch_size'] = config.get("micro_batch_size", 1) + return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank) + + +def get_concat_tarred_audio_text_dataset( + config, + tokenizer, + augmentor, + global_rank=0, + world_size=1, + shuffle_n=0, + sep_id=None, + answer_only_loss=True, + virtual_tokens=0, +): + tarred_audio_filepaths = config['tarred_audio_filepaths'] + manifest_filepaths = config['manifest_filepath'] + datasets = [] + for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate( + zip(tarred_audio_filepaths, manifest_filepaths) + ): + conf = copy.deepcopy(config) + conf['manifest_filepath'] = manifest_filepath + conf['tarred_audio_filepaths'] = tarred_audio_filepath + context_files = config.get('context_file', None) + if isinstance(context_files, ListConfig) and len(context_files) == len(manifest_filepaths): + conf['context_file'] = context_files[dataset_idx] + else: + conf['context_file'] = context_files + dataset = get_tarred_audio_text_dataset( + config=conf, + tokenizer=tokenizer, + shuffle_n=shuffle_n, + global_rank=global_rank, + world_size=world_size, + augmentor=augmentor, + sep_id=sep_id, + answer_only_loss=answer_only_loss, + virtual_tokens=virtual_tokens, + ) + datasets.append(dataset) + + concat_sampling_probabilities = config.get('concat_sampling_probabilities', None) + if not isinstance(concat_sampling_probabilities, ListConfig) or len(concat_sampling_probabilities) != len( + datasets + ): + logging.info( + f"concat_sampling_probabilities is not provided or is not of the same size as datasets, using uniform sampling." + ) + concat_sampling_probabilities = [1.0 / len(datasets)] * len(datasets) + + dataset = ConcatDataset( + datasets, + sampling_technique=config.get('concat_sampling_technique', 'temperature'), + sampling_temperature=config.get('concat_sampling_temperature', 5), + sampling_scale=config.get('concat_sampling_scale', 1), + sampling_probabilities=concat_sampling_probabilities, + shuffle=config.get('concat_shuffle', True), + seed=config.get('concat_sampling_seed', None), + global_rank=global_rank, + world_size=world_size, + ) + return dataset + + +def get_tarred_audio_text_dataset_from_config( + config: DictConfig, + tokenizer, + augmentor, + global_rank: int = 0, + world_size: int = 1, + sep_id: Optional[int] = None, + answer_only_loss: bool = True, + virtual_tokens: int = 0, +): + is_concat = config.get('is_concat', False) + if is_concat: + if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None: + logging.warning( + f"Concat dataset requires `concat_sampling_technique` but it was not provided. Config: {config}" + ) + return None + + data_parallel_size = parallel_state.get_data_parallel_world_size() + num_micro_batches = config.global_batch_size // (config.micro_batch_size * data_parallel_size) + global_batch_size_on_this_data_parallel_rank = num_micro_batches * config.micro_batch_size + shuffle = config['shuffle'] + shuffle_n = config.get('shuffle_n', 4 * global_batch_size_on_this_data_parallel_rank) if shuffle else 0 + if is_concat: + dataset = get_concat_tarred_audio_text_dataset( + config=config, + tokenizer=tokenizer, + augmentor=augmentor, + shuffle_n=shuffle_n, + global_rank=global_rank, + world_size=world_size, + sep_id=sep_id, + answer_only_loss=answer_only_loss, + virtual_tokens=virtual_tokens, + ) + else: + dataset = get_tarred_audio_text_dataset( + config=config, + tokenizer=tokenizer, + augmentor=augmentor, + shuffle_n=shuffle_n, + global_rank=global_rank, + world_size=world_size, + sep_id=sep_id, + answer_only_loss=answer_only_loss, + virtual_tokens=virtual_tokens, + ) + return dataset + + +def get_audio_text_dataset_from_config( + manifest_filepath: str, + config: DictConfig, + tokenizer, + augmentor, + is_train, + sep_id: Optional[int] = None, + answer_only_loss: bool = True, + virtual_tokens: int = 0, +): + if isinstance(config.manifest_filepath, str): + manifest_filepath = config.manifest_filepath.split(',') + else: + manifest_filepath = config.manifest_filepath + + data_cls = MultiAudioTextDataset if config.get('audio_locator', None) else AudioTextDataset + datasets = [] + if is_train: + # Construct the data prefix list for `get_datasets_weights_and_num_samples()` + # that is of the format [weight1,file_name1,weight2,file_name2,...] + concat_sampling_probabilities = config.get('concat_sampling_probabilities', None) + if concat_sampling_probabilities is None: + concat_sampling_probabilities = [1.0 / len(manifest_filepath)] * len(manifest_filepath) + elif len(config.get('concat_sampling_probabilities', None)) != len(manifest_filepath): + raise ValueError( + ( + f"concat_sampling_probabilities must be of the same size as manifest_filepath.", + f"Provided size {len(config.concat_sampling_probabilities)}, number of datasets {len(manifest_filepath)}", + ) + ) + data_prefix = [] + for weight, prefix in zip(concat_sampling_probabilities, manifest_filepath): + data_prefix.append(weight) + data_prefix.append(prefix) + + num_samples_per_dataset = get_num_samples_from_files(manifest_filepath) + num_train_samples = [len(manifest_filepath) * max(num_samples_per_dataset)] + _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples(data_prefix, num_train_samples) + num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) + else: + num_train_samples_per_dataset = [[None]] * len(manifest_filepath) + + for dataset_idx, (file_path, num_samples) in enumerate(zip(manifest_filepath, num_train_samples_per_dataset)): + context_file = config.get('context_file', None) + if isinstance(context_file, ListConfig) and len(context_file) == len(manifest_filepath): + context_file = context_file[dataset_idx] + dataset = data_cls( + manifest_filepath=file_path, + tokenizer=tokenizer, + sample_rate=config.sample_rate, + int_values=config.get('int_values', False), + augmentor=augmentor, + max_duration=getattr(config, 'max_duration', None), + min_duration=getattr(config, 'min_duration', None), + max_utts=getattr(config, 'max_utts', -1), + trim=getattr(config, 'trim_silence', False), + channel_selector=getattr(config, 'channel_selector', None), + max_seq_length=config.max_seq_length, + min_seq_length=config.min_seq_length, + add_bos=config.get('add_bos', False), + add_eos=config.get('add_eos', True), + add_sep=config.get('add_sep', False), + sep_id=sep_id, + max_num_samples=num_samples[0], + seed=config.get('seed', 1234), + separate_prompt_and_response_with_newline=config.get('separate_prompt_and_response_with_newline', True), + answer_only_loss=answer_only_loss, + truncation_field=config.get('truncation_field', 'context'), + pad_to_max_length=config.get('pad_to_max_length', False), + prompt_template=config.get('prompt_template', None), + virtual_tokens=virtual_tokens, + tokens_to_generate=config.get( + 'tokens_to_generate', 0 + ), # used at inference time to allocate tensor positions for tokens that will be generated by inf procedure. + context_key=config.get('context_key', 'context'), + answer_key=config.get('answer_key', 'answer'), + end_string=config.get('end_string', None), + sample_alpha=config.get('sample_alpha', None), + context_file=context_file, + audio_locator=config.get('audio_locator', None), + ) + datasets.append(dataset) + + if is_train: + dataset = BlendableDataset( + datasets=datasets, weights=concat_sampling_probabilities, size=num_train_samples_after_blend + ) + return dataset + else: + return datasets diff --git a/nemo/collections/multimodal/speech_llm/models/__init__.py b/nemo/collections/multimodal/speech_llm/models/__init__.py new file mode 100644 index 000000000000..ec188828ec87 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py new file mode 100644 index 000000000000..39bc37c33e56 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -0,0 +1,1563 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json +import os +from typing import List, Optional, Union + +import hydra +import sacrebleu +import torch +from hydra.utils import get_class +from omegaconf import ListConfig +from omegaconf.dictconfig import DictConfig +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.utilities import rank_zero_only + +from nemo.collections.asr.models import ASRModel, EncDecSpeakerLabelModel +from nemo.collections.asr.parts.mixins.transcription import move_to_device +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.utils.eval_utils import remove_punctuations +from nemo.collections.common.metrics import MetricStringToTorchMetric, TextMetricsSet +from nemo.collections.multimodal.speech_llm.data.audio_text_dataset import ( + get_audio_text_dataset_from_config, + get_tarred_audio_text_dataset_from_config, +) +from nemo.collections.multimodal.speech_llm.modules.common.audio_text_generation_utils import generate +from nemo.collections.multimodal.speech_llm.modules.perception_modules import ( + AudioPerceptionModule, + MultiAudioPerceptionModule, +) +from nemo.collections.multimodal.speech_llm.parts.mixins.adapter_mixin import SpeechLLMAdapterMixin +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import get_nested_dict_value +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( + MegatronPretrainingBatchSampler, +) +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + build_position_ids, +) +from nemo.collections.nlp.modules.common.text_generation_utils import get_computeprob_response +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import AppState, logging +from nemo.utils.model_utils import inject_model_parallel_rank + +try: + from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator, get_num_microbatches + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.models.gpt import GPTModel as MCoreGPTModel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False + + +__all__ = ["ModularAudioGPTModel"] + + +default_inference_config = {'tokens_to_generate': 30} + + +class ModularAudioGPTModel(SpeechLLMAdapterMixin, MegatronGPTSFTModel): + """Modularized speech GPT model.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer): + self.cfg = cfg + super().__init__(cfg, trainer) + + self.perception = ( + AudioPerceptionModule(cfg=cfg.perception) + if "encoders" not in cfg.perception + else MultiAudioPerceptionModule(cfg=cfg.perception) + ) + # print out params in more details + self.summarize(max_depth=2) + + def parameters(self): + # override the same method in MegatronGPT model to include parameters ouside of LM + all_names = [] + all_params = [] + for name, param in self.named_parameters(recurse=True): + all_names.append(name) + all_params.append(param) + + if isinstance(self.model, list): + for module in self.model: + for name, param in module.named_parameters(recurse=True): + all_names.append(name) + all_params.append(param) + + return itertools.chain(all_params) + + def setup_optimizer_param_groups(self): + """ + Override parent method to setup optimizer groups for training/freezing different parts of the model. + """ + known_groups = [] + if self.cfg.get('freeze_llm', True): + for param in self.model.parameters(): + param.requires_grad = False + known_groups.append('model.') + + if self.cfg.get('freeze_audio_encoder', False): + # freeze speaker model if there is any + if self.cfg.perception.get("speaker_model", None) is not None: + if self.cfg.perception.speaker_model.get("freeze", False): + self.perception.speaker_model.freeze() + known_groups.append('perception.speaker_model.') + # freeze other audio encoders + if self.cfg.perception.get("encoders", None) is not None: + # multiple audio encoders + for key, enc_cfg in self.cfg.perception.encoders.items(): + if enc_cfg.get("freeze", False): + self.perception.encoders[key].freeze() + known_groups.append(f'perception.encoders.{key}.') + else: + # single audio encoder + self.perception.encoder.freeze() + known_groups.append('perception.encoder.') + + if self.cfg.get('freeze_modality_adapter', False): + # freeze modality adapter + self.perception.modality_adapter.freeze() + known_groups.append('perception.modality_adapter.') + + opt_params = [] + for _, module in self.named_modules(): + if isinstance(module, adapter_mixins.AdapterModuleMixin) and module.is_adapter_available(): + # add adapters to the optimizer + module.set_enabled_adapters(enabled=True) + module.unfreeze_enabled_adapters() # selectively unfreeze the adapter modules. + opt_params += [p for p in module.parameters()] + + # add param groups with specified args, if any + param_groups = [] + if "optim_param_groups" in self.cfg: + param_groups_cfg = self.cfg.optim_param_groups + for group, group_cfg in param_groups_cfg.items(): + module = getattr(self, group, None) + if module is None: + raise ValueError(f"{group} not found in model.") + elif hasattr(module, "parameters"): + known_groups.append(f"{group}.") + new_group = {"params": module.parameters()} + for k, v in group_cfg.items(): + new_group[k] = v + param_groups.append(new_group) + else: + raise ValueError(f"{group} does not have parameters.") + + # add other trainable params + for n, p in self.named_parameters(): + is_unknown = True + for group in known_groups: + if n.startswith(group): + is_unknown = False + if is_unknown: + opt_params.append(p) + + param_groups = [{"params": opt_params}] + param_groups + + self._optimizer_param_groups = param_groups + logging.info(f"Optimizer groups set:\n{self.summarize(max_depth=2)}") + + def _create_attention_mask(self, encoder_input: torch.Tensor): + # Create causal attention mask for whole input + batch_size = encoder_input.shape[0] + max_len = encoder_input.shape[1] + attention_mask = torch.tril(torch.ones((batch_size, max_len, max_len), device=encoder_input.device)).view( + batch_size, 1, max_len, max_len + ) + # Convert attention mask from float to bool + attention_mask = attention_mask < 0.5 + return attention_mask + + def _concat_features(self, embs1, emb1_lens, embs2, emb2_lens): + """Concatenate two sets of embeddings and their lengths.""" + concat_emb = [] + concat_len = [] + for emb1, emb1_len, emb2, emb2_len in zip(embs1, emb1_lens, embs2, emb2_lens): + new_len = emb1_len + emb2_len + new_emb = torch.concat([emb1[:emb1_len], emb2[:emb2_len]], axis=0) + padded_new_emb = torch.zeros(emb1.shape[0] + emb2.shape[0], emb1.shape[-1], device=emb1.device) + padded_new_emb[:new_len, ...] = new_emb + concat_emb.append(padded_new_emb) + concat_len.append(new_len) + concat_emb = torch.stack(concat_emb, dim=0) + concat_len = torch.stack(concat_len, dim=0) + return concat_emb, concat_len + + def _concat_multi_features( + self, + encoded: List[torch.Tensor], + encoded_len: List[torch.Tensor], + input_embeds: torch.Tensor, + input_length: torch.Tensor, + context_start_idx: List[List[int]], + ): + """Concatenate multiple audio features with text segments.""" + encoder_input_list, encoder_length_list = [], [] + batch_size = input_embeds.size(0) + max_length = 0 + for i in range(batch_size): + start_idx_list_i = context_start_idx[i] + [ + input_embeds.size(1) + ] # use input_embeds instead of input_length to handle tokens_to_generate in inference + input_len_list = [start_idx_list_i[j + 1] - start_idx_list_i[j] for j in range(len(start_idx_list_i) - 1)] + input_emb_list = input_embeds[i].split(input_len_list) + encoder_input_i = [input_emb_list[0]] + for j in range(1, len(input_emb_list)): + encoder_input_i.append(encoded[i][j - 1][: encoded_len[i][j - 1]]) + encoder_input_i.append(input_emb_list[j]) + encoder_input_i = torch.cat(encoder_input_i) # T, C + encoder_length_i = encoded_len[i].sum() + input_length[i] # total length of audio and text features + max_length = max(max_length, encoder_input_i.size(0)) + encoder_input_list.append(encoder_input_i) + encoder_length_list.append(encoder_length_i) + + encoder_input = torch.stack( + [torch.nn.functional.pad(f, (0, 0, 0, max_length - f.size(0))) for f in encoder_input_list] + ) + encoder_length = torch.LongTensor(encoder_length_list).to(encoder_input.device) + return encoder_input, encoder_length + + def inject_perception_input( + self, + encoded: Union[torch.Tensor, List[torch.Tensor]], + encoded_len: Union[torch.Tensor, List[torch.Tensor]], + input_ids: torch.Tensor, + input_length: torch.Tensor, + context_start_idx: Optional[List[List[int]]] = None, + ): + """Inject audio features into the text input and return the final input embeddings to LLM.""" + # [b, t, c] + lm_embedding = ( + self.model.language_model.embedding if hasattr(self.model, 'language_model') else self.model.embedding + ) + input_embeds = lm_embedding.word_embeddings(input_ids) + if isinstance(encoded, torch.Tensor): + # single audio + encoder_input, encoder_length = self._concat_features(encoded, encoded_len, input_embeds, input_length) + else: + # concat multiple audios with text segments + encoder_input, encoder_length = self._concat_multi_features( + encoded, encoded_len, input_embeds, input_length, context_start_idx + ) + + attention_mask = self._create_attention_mask(encoder_input) + position_ids = build_position_ids(encoder_input[:, :, 0]) + + # Add position embeddings + if ( + getattr(lm_embedding, "position_embeddings", None) is not None + and lm_embedding.position_embedding_type == 'learned_absolute' + ): + position_embeddings = lm_embedding.position_embeddings(position_ids) + encoder_input = encoder_input + position_embeddings + + encoder_max_length = encoder_input.shape[1] + if not hasattr(lm_embedding, 'transpose_batch_sequence') or lm_embedding.transpose_batch_sequence: + encoder_input = encoder_input.transpose(0, 1).contiguous() + if self.cfg.get("sequence_parallel", False): + encoder_input = tensor_parallel.mappings.scatter_to_sequence_parallel_region(encoder_input) + return encoder_input, attention_mask, encoder_length, position_ids, encoder_max_length + + def _shift_labels_by_emb_len(self, labels, label_lens, emb_lens, max_len, pad_token=0): + """Shift labels to the right by the length of the audio embeddings.""" + shifted_labels = [] + for label, label_len, emb_len in zip(labels, label_lens, emb_lens): + shifted_label = torch.full([max_len], pad_token, device=label.device) + shifted_label[emb_len : emb_len + label_len] = label[:label_len] + shifted_labels.append(shifted_label) + shifted_labels = torch.stack(shifted_labels, dim=0) + return shifted_labels + + def _get_text_embeddings(self, text_tokens, position_ids): + """Get text embeddings for the input text tokens.""" + lm_embedding = ( + self.model.language_model.embedding if hasattr(self.model, 'language_model') else self.model.embedding + ) + text_embeddings = lm_embedding.word_embeddings(text_tokens) # (batch_size, seq_len, hidden_size) + if hasattr(lm_embedding, 'position_embeddings'): + position_embeddings = lm_embedding.position_embeddings(position_ids) + text_embeddings = text_embeddings + position_embeddings + return text_embeddings.transpose(0, 1) + + def prepare_llm_input(self, audio_batch): + """Prepare input for the LLM.""" + input_signal = audio_batch['audio_signal'] + input_signal_length = audio_batch['audio_signal_length'] + + input_ids, input_length, labels, loss_mask = ( + audio_batch['tokens'], + audio_batch['tokens_length'], + audio_batch['labels'], + audio_batch['loss_mask'], + ) + + num_audios = audio_batch.get("num_audios", None) + context_start_idx = audio_batch.get("context_start_idx", None) + + # [b, t, c] + encoded, encoded_len = self.perception( + input_signal=input_signal, + input_signal_length=input_signal_length, + processed_signal=None, + processed_signal_length=None, + ) + + if num_audios is not None: + # split the encoded and encoded_len by num_audios, used when there're multiple audio files per sample + encoded = encoded.split(num_audios.tolist()) + encoded_len = encoded_len.split(num_audios.tolist()) + + encoder_input, attention_mask, encoder_length, _, encoder_max_length = self.inject_perception_input( + encoded, encoded_len, input_ids, input_length, context_start_idx + ) + if num_audios is not None: + # sum up the audio_feat_lens for each sample in the batch + encoded_len = torch.stack([torch.sum(lens) for lens in encoded_len]) + + # Shift labels to the right + labels = self._shift_labels_by_emb_len(labels, input_length, encoded_len, encoder_max_length, pad_token=0) + # Loss mask where answer tokens are 1.0 and all other tokens are 0.0 + loss_mask = self._shift_labels_by_emb_len( + loss_mask, input_length, encoded_len, encoder_max_length, pad_token=0 + ) + + return encoder_input, attention_mask, labels, loss_mask, encoder_length + + def forward( + self, + audio_batch, + checkpoint_activations_all_layers, + ): + """ + Forward pass of the model. We prepend audio embeddings to the instruction and label text tokens as the LLM input. + """ + encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch) + if self.mcore_gpt: + output = self.model( + input_ids=None, + position_ids=None, + decoder_input=encoder_input, + attention_mask=attention_mask, + labels=labels, + ) + else: + output = self.model( + input_ids=None, + position_ids=None, + encoder_input=encoder_input, + attention_mask=attention_mask, + labels=labels, + checkpoint_activations_all_layers=checkpoint_activations_all_layers, + ) + + return output, loss_mask + + def get_forward_output_only_func(self): + def fwd_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + extra_arg = {} + # take the batch produced by prepare_batch_at_step + ( + tokens, + input_embeddings, + attention_mask, + position_ids, + set_inference_key_value_memory, + inference_max_sequence_len, + ) = batch + tokens = tokens.cuda() + + if attention_mask is not None: + attention_mask = attention_mask.cuda() + attention_mask = attention_mask[0:1] + if self.mcore_gpt: + # if first step, then clear KV cache, otherwise reuse inference_paarms + if set_inference_key_value_memory[0].item(): + self.inference_params = InferenceParams( + max_batch_size=tokens.size(0), max_sequence_length=inference_max_sequence_len[0].item() + ) + extra_arg['inference_params'] = self.inference_params + else: + extra_arg['set_inference_key_value_memory'] = set_inference_key_value_memory[0].item() + extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item() + + # Currently for all MCore transformer layer specs causal attention mask + # is used so we can delegate creating it to MCore/TE and pass None below + if ( + isinstance(model, MCoreGPTModel) + or hasattr(model, "module") + and isinstance(model.module, MCoreGPTModel) + ): + attention_mask = None + + output_tensor = model( + input_ids=None, + position_ids=None, + decoder_input=input_embeddings, + attention_mask=attention_mask, + **extra_arg, + ) + + # Advance inference sequence offset. + if self.inference_params: + # if last stage, then (final) output is [b, s, h], otherwise it's [s, b, h] + if parallel_state.is_pipeline_last_stage(): + self.inference_params.sequence_len_offset += output_tensor.size(1) + else: + self.inference_params.sequence_len_offset += output_tensor.size(0) + + def id_func(output_tensor): + return output_tensor, {'logits': output_tensor} + + return output_tensor, id_func + + return fwd_output_only_func + + def get_forward_output_and_loss_func(self, validation_step=False, tuning=False): + def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): + batch = next(dataloader_iter) + + # Transfer needed data to GPU + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + required_keys.add('attention_mask') + if parallel_state.is_pipeline_first_stage(): + required_keys.update(('tokens', 'position_ids')) + if parallel_state.is_pipeline_last_stage(): + required_keys.update(('labels', 'loss_mask')) + if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys: + required_keys.remove('attention_mask') + + batch = move_to_device(batch, self.device) + batch = self.get_batch_on_this_context_parallel_rank(batch) + + if not self.mcore_gpt: + batch['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers + + output_tensor, loss_mask = self.forward( + batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers + ) + batch['loss_mask'] = loss_mask + + def loss_func(output_tensor): + # Loss for a micro-batch (ub) + loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) + cp_size = self.cfg.get('context_parallel_size', 1) + if self.cfg.data.get( + "return_output_tensors", False + ): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare) + loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + pos_cs = average_losses_across_data_parallel_group([pos_cs]) + neg_cs = average_losses_across_data_parallel_group([neg_cs]) + diff_cs = average_losses_across_data_parallel_group([diff_cs]) + return ( + loss_for_ub * cp_size, + { + 'avg': reduced_loss, + 'query_hs': q_hs, + 'doc_hs': d_hs, + 'avg_pos_cs': pos_cs, + 'avg_neg_cs': neg_cs, + 'diff_cs': diff_cs, + }, + ) + elif validation_step and not self.cfg.data.get('validation_drop_last', True): + num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] + if loss_for_ub.isnan(): + assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' + loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub) + else: + loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub + + loss_sum_and_ub_size_all_gpu = torch.cat( + [ + loss_sum_for_ub.clone().detach().view(1), + torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), + ] + ) + # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds) + torch.distributed.all_reduce( + loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() + ) + return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} + else: + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub * cp_size, {'avg': reduced_loss} + + return output_tensor, loss_func + + return fwd_output_and_loss_func + + def _build_dataset(self, data_cfg, is_train=True): + if 'augmentor' in data_cfg: + augmentor = process_augmentations( + data_cfg['augmentor'], global_rank=self.global_rank, world_size=self.world_size + ) + else: + augmentor = None + + # Check dataset max_seq_legnth and max_position_embeddings size + if ( + self.cfg.get('position_embedding_type', None) in [None, 'learned_absolute'] + and data_cfg.max_seq_length > self.cfg.max_position_embeddings + ): + logging.warning( + f"Set dataset max_seq_length to max_position_embeddings {self.cfg.max_position_embeddings} if using learned_absolute position embedding" + ) + data_cfg.max_seq_length = self.cfg.max_position_embeddings + + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type. + if data_cfg.get('is_tarred', False): + return get_tarred_audio_text_dataset_from_config( + config=data_cfg, + tokenizer=self.tokenizer, + augmentor=augmentor, + sep_id=self.sep_id, + answer_only_loss=self.cfg.get('answer_only_loss', True), + virtual_tokens=self.virtual_tokens, + global_rank=parallel_state.get_data_parallel_rank(), + world_size=parallel_state.get_data_parallel_world_size(), + ) + else: + return get_audio_text_dataset_from_config( + manifest_filepath=data_cfg.manifest_filepath, + config=data_cfg, + tokenizer=self.tokenizer, + augmentor=augmentor, + is_train=is_train, + sep_id=self.sep_id, + answer_only_loss=self.cfg.get('answer_only_loss', True), + virtual_tokens=self.virtual_tokens, + ) + + def build_data_loader(self, dataset, data_cfg, consumed_samples=0, is_predict=False): + """Buld dataloader given an input dataset.""" + logging.info(f'Building dataloader with consumed samples: {consumed_samples}') + if isinstance(dataset, BlendableDataset): + collate_fn = dataset.datasets[0].collate_fn + elif hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + # support datasets that are lists of entries + collate_fn = dataset.datasets[0].collate_fn + else: + # support datasets that are lists of lists + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + if isinstance(dataset, torch.utils.data.IterableDataset): + data_parallel_size = parallel_state.get_data_parallel_world_size() + num_micro_batches = data_cfg.global_batch_size // (data_cfg.micro_batch_size * data_parallel_size) + global_batch_size_on_this_data_parallel_rank = num_micro_batches * data_cfg.micro_batch_size + + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=collate_fn, + shuffle=False, + batch_size=global_batch_size_on_this_data_parallel_rank, + drop_last=True, + num_workers=data_cfg.num_workers, + pin_memory=data_cfg.pin_memory, + ) + return dataloader + + if is_predict: + # MegatronPretrainingBatchSampler doesn't work with trainer.predict() + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=collate_fn, + batch_size=data_cfg.micro_batch_size, + num_workers=data_cfg.num_workers, + pin_memory=data_cfg.pin_memory, + ) + return dataloader + + batch_sampler = MegatronPretrainingBatchSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=data_cfg.micro_batch_size, + global_batch_size=data_cfg.global_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=data_cfg.drop_last, + pad_samples_to_global_batch_size=not data_cfg.drop_last, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=data_cfg.num_workers, + pin_memory=data_cfg.pin_memory, + persistent_workers=True if data_cfg.num_workers > 0 else False, + ) + return dataloader + + @classmethod + def _modify_audio_encoder_config(cls, gpt_cfg, audio_cfg, speaker_cfg=None): + """load the ecoder configs from the pretrained audio models and updating the model's config.""" + with open_dict(gpt_cfg): + use_multi_encoder = gpt_cfg.perception.get("encoders", None) is not None + if not use_multi_encoder: + gpt_cfg.perception.preprocessor = audio_cfg.preprocessor + gpt_cfg.perception.encoder = audio_cfg.encoder + else: + for key in gpt_cfg.perception.encoders: + model_key = gpt_cfg.perception.encoders[key].get("model_key", "encoder") + gpt_cfg.perception.encoders[key]["model"] = audio_cfg[key][model_key] + if "preprocessor" in audio_cfg[key]: + gpt_cfg.perception.encoders[key]['preprocessor'] = audio_cfg[key].preprocessor + if speaker_cfg is not None: + gpt_cfg.perception.speaker_model.model = speaker_cfg + + gpt_cfg.perception.output_dim = gpt_cfg.hidden_size + modality_adapter_cfg = gpt_cfg.perception.modality_adapter + if 'output_dim' in modality_adapter_cfg: + modality_adapter_cfg.output_dim = gpt_cfg.hidden_size + if not use_multi_encoder: + model_dim_key = gpt_cfg.perception.get("model_dim_key", "d_model") + encoder_dim = get_nested_dict_value(audio_cfg.encoder, model_dim_key) + input_dim = encoder_dim + if ( + gpt_cfg.perception.get('use_multi_layer_feat', False) + and gpt_cfg.perception.multi_layer_feat.aggregator.get("mode", "cat") == "cat" + ): + input_dim = encoder_dim * len(gpt_cfg.perception.multi_layer_feat.layer_idx_list) + else: + input_dim = 0 + if speaker_cfg is not None: + input_dim += speaker_cfg.decoder.emb_sizes + for enc_cfg in gpt_cfg.perception.encoders.values(): + encoder_dim = get_nested_dict_value(enc_cfg.model, enc_cfg.get("model_dim_key", "d_model")) + if ( + enc_cfg.get('use_multi_layer_feat', False) + and enc_cfg.multi_layer_feat.aggregator.get("mode", "cat") == "cat" + ): + input_dim += encoder_dim * len(enc_cfg.multi_layer_feat.layer_idx_list) + else: + input_dim += encoder_dim + + if 'feat_in' in modality_adapter_cfg: + modality_adapter_cfg.feat_in = input_dim + elif 'input_dim' in modality_adapter_cfg: + modality_adapter_cfg.input_dim = input_dim + + @classmethod + def _modify_config(cls, gpt_cfg, cfg, audio_cfg, add_cfg_to_tree=False, speaker_cfg=None): + """ + This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg). + The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. + """ + OmegaConf.set_struct(gpt_cfg, True) + OmegaConf.resolve(cfg) + with open_dict(gpt_cfg): + # for AudioGPTLoRAModel + gpt_cfg.target = f"{cls.__module__}.{cls.__name__}" + gpt_cfg.perception = cfg.model.perception + # inject audio encoder configs into the target config (gpt_cfg) + cls._modify_audio_encoder_config(gpt_cfg, audio_cfg, speaker_cfg) + + # inject the sample rate from the audio encoder into the gpt config + if isinstance(audio_cfg, (ListConfig, list)): + sample_rate = [_cfg.preprocessor.sample_rate for _cfg in audio_cfg] + if not all([sr == sample_rate[0] for sr in sample_rate]): + raise ValueError("All audio encoders must have the same sample rate.") + gpt_cfg.data.train_ds.sample_rate = sample_rate[0] + gpt_cfg.data.validation_ds.sample_rate = sample_rate[0] + else: + sample_rate = audio_cfg.preprocessor.sample_rate + gpt_cfg.data.train_ds.sample_rate = sample_rate + gpt_cfg.data.validation_ds.sample_rate = sample_rate + + # This is needed when modifying a hparam file directly to load `.ckpt` files. + # This is not needed to modify the cfg in `.nemo` files. + if add_cfg_to_tree: + OmegaConf.resolve(gpt_cfg) + gpt_cfg.cfg = gpt_cfg + + return gpt_cfg + + @classmethod + def get_pretraind_audio_model(cls, encoder_cfg: DictConfig) -> ModelPT: + """load pretrained audio model from a given config""" + if encoder_cfg.get("_target_", None) is not None: + encoder_cls = get_class(encoder_cfg.get("_target_")) + elif encoder_cfg.get("target", None) is not None: + encoder_cls = get_class(encoder_cfg.get("target")) + else: + encoder_cls = ASRModel + + pretrained_model = encoder_cfg.get('pretrained_model', None) + if pretrained_model is None: + return None + if encoder_cls is None: + raise ValueError( + f"Must specify a valid encoder class in the via the `_target_` field in the config: {encoder_cfg}" + ) + + if pretrained_model.endswith('.nemo'): + logging.info(f'Loading pretrained audio model from local file: {pretrained_model}') + audio_model = encoder_cls.restore_from(pretrained_model, map_location='cpu') + else: + logging.info(f'Loading pretrained audio model from NGC: {pretrained_model}') + audio_model = encoder_cls.from_pretrained(pretrained_model, map_location='cpu') + return audio_model + + @classmethod + def get_speaker_model_and_config(cls, cfg): + """load speaker embedding model and config if present in the config.""" + if 'speaker_model' in cfg.model.perception: + if cfg.model.get("_target_", None) is not None: + model_cls = get_class(cfg.model.get("_target_")) + elif cfg.model.get("target", None) is not None: + model_cls = get_class(cfg.model.get("target")) + else: + model_cls = EncDecSpeakerLabelModel + + speaker_cfg = cfg.model.perception.speaker_model + if speaker_cfg.get('pretrained_model', None) is not None: + if speaker_cfg.pretrained_model.endswith('.nemo'): + logging.info(f'Loading pretrained speaker model from local file: {speaker_cfg.pretrained_model}') + speaker_model = model_cls.restore_from(speaker_cfg.pretrained_model, map_location='cpu') + else: + logging.info(f'Loading pretrained speaker model from NGC: {speaker_cfg.pretrained_model}') + speaker_model = model_cls.from_pretrained(speaker_cfg.pretrained_model, map_location='cpu') + return speaker_model, speaker_model.cfg + return None, None + else: + return None, None + + @classmethod + def get_audio_encoder_models_and_configs(cls, cfg): + if 'encoders' in cfg.model.perception: + audio_encoders = {} + audio_enc_cfgs = {} + for key, encoder_cfg in cfg.model.perception.encoders.items(): + audio_encoders[key] = cls.get_pretraind_audio_model(encoder_cfg) + audio_enc_cfgs[key] = audio_encoders[key].cfg + return audio_encoders, audio_enc_cfgs + else: + pretrained_audio_model = cfg.model.get("pretrained_audio_model", None) + pretrained_audio_model_class = cfg.model.get( + "pretrained_audio_model_target", "nemo.collections.asr.models.ASRModel" + ) + + model_class = hydra.utils.get_class(pretrained_audio_model_class) + if pretrained_audio_model.endswith('.nemo'): + logging.info(f'Loading pretrained audio model from local file: {pretrained_audio_model}') + audio_model = model_class.restore_from(pretrained_audio_model, map_location='cpu') + else: + logging.info(f'Loading pretrained audio model from NGC: {pretrained_audio_model}') + audio_model = model_class.from_pretrained(pretrained_audio_model, map_location='cpu') + return audio_model, audio_model.cfg + + @classmethod + def load_pretrained_audio_weights( + cls, cfg, model, audio_model, speaker_model: Optional[EncDecSpeakerLabelModel] = None + ): + use_multi_encoder = cfg.model.perception.get("encoders", None) is not None + if not use_multi_encoder: + if cfg.model.perception.get("use_multi_layer_feat", False): + model.perception.encoder.encoder.load_state_dict(audio_model.encoder.state_dict(), strict=True) + else: + model.perception.encoder.load_state_dict(audio_model.encoder.state_dict(), strict=True) + logging.info(f'Loaded pretrained audio model weights from {cfg.model.pretrained_audio_model}') + if cfg.model.get('use_am_tokenizer', False): + model.tokenizer = audio_model.tokenizer + logging.info(f'Use AM tokenizer: {audio_model.tokenizer}') + return model + else: + for key, enc_cfg in cfg.model.perception.encoders.items(): + if enc_cfg.get("use_multi_layer_feat", False): + model.perception.encoders[key].encoder.load_state_dict( + audio_model[key].encoder.state_dict(), strict=True + ) + else: + model.perception.encoders[key].load_state_dict(audio_model[key].encoder.state_dict(), strict=True) + logging.info(f'Loaded pretrained audio model weights for {key}') + if speaker_model is not None: + model.perception.speaker_model.load_state_dict(speaker_model.state_dict(), strict=True) + logging.info(f'Loaded pretrained speaker model weights') + return model + + @classmethod + def restore_from_pretrained_models( + cls, + cfg: Optional[Union[OmegaConf, str]] = None, + trainer: Optional[Trainer] = None, + ): + """ + load pretrained LLM and audio encoders, and maybe add adapters, used for training. + Args: + cfg: input yaml config, with trainer, model, exp_manager, etc. + trainer: trainer object + """ + if ( + cfg.model.get("pretrained_audio_model", None) is None + and cfg.model.perception.get("encoders", None) is None + ): + raise RuntimeError("PEFT training needs at least one pretrained audio model present.") + + if not cfg.model.restore_from_path: + raise RuntimeError("PEFT training needs a trained base model present.") + + base_model_cfg = MegatronGPTSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg) + audio_model, audio_model_cfg = cls.get_audio_encoder_models_and_configs(cfg) + speaker_model, speaker_cfg = cls.get_speaker_model_and_config(cfg) + model_cfg = cls._modify_config( + base_model_cfg, cfg, audio_model_cfg, add_cfg_to_tree=False, speaker_cfg=speaker_cfg + ) + + # load llm + model = cls.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + override_config_path=model_cfg, + strict=False, + map_location="cpu", + ) + + if "peft" in cfg.model: + peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] + if cfg.model.peft.restore_from_path is not None: + # initialize peft weights from a checkpoint instead of randomly + # This is not the same as resume training because optimizer states are not restored. + logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path) + model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg), map_location="cpu") + elif peft_cfg_cls is not None: + logging.info("Adding adapter weights to the model for PEFT") + model.add_adapter(peft_cfg_cls(model_cfg)) + else: + raise ValueError(f"PEFT scheme not not found in PEFT_CONFIG_MAP: {cfg.model.peft.peft_scheme}") + else: + logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}") + + # load audio model weights + model = cls.load_pretrained_audio_weights(cfg, model, audio_model, speaker_model) + + if 'inference' in cfg: + inference_cfg = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(inference_cfg) + return model + + @classmethod + def load_audio_encoder_for_inference(cls, cfg: DictConfig, model_cfg: DictConfig, model: ModelPT) -> ModelPT: + """ + Maybe load audio encoders for inference, if they were not tunable during training. + Args: + cfg: inference config + model_cfg: model config + model: model object + Returns: + model: model object with audio encoder weights loaded + """ + if model_cfg.freeze_audio_encoder and model_cfg.get("pretrained_audio_model", None) is not None: + with open_dict(cfg): + cfg.model.perception = model_cfg.perception + + audio_model, _ = cls.get_audio_encoder_models_and_configs(cfg) + speaker_model, _ = cls.get_speaker_model_and_config(cfg) + model = cls.load_pretrained_audio_weights(cfg, model, audio_model, speaker_model) + return model + + @classmethod + def merge_inference_cfg( + cls, cfg: DictConfig, trainer: Trainer, pretrained_model_cfg: DictConfig = None + ) -> DictConfig: + """ + Merge the inference config with the model config, used for inference only. + if no pretrained_model_cfg is given, it will be loaded from the checkpoint specified in cfg. + Args: + cfg: inference config + trainer: trainer object + pretrained_model_cfg: a pre-loaded SpeechLLM model config + Returns: + model_cfg: merged model config + """ + if pretrained_model_cfg: + model_cfg = pretrained_model_cfg + elif cfg.model.peft.restore_from_path: + if cfg.model.peft.restore_from_path.endswith(".nemo"): + model_cfg = ModularAudioGPTModel.restore_from( + restore_path=cfg.model.peft.restore_from_path, + trainer=trainer, + return_config=True, + ) + elif cfg.model.peft.restore_from_hparams_path: # not a .nemo model we expect a hparams.yaml file + model_cfg = OmegaConf.to_container(OmegaConf.load(cfg.model.peft.restore_from_hparams_path).cfg) + model_cfg = OmegaConf.create(model_cfg) + # extract dict inside cfg key and convert it to DictConfig + # this allows interpolation to work the same way as config from the .restore_from method + else: + raise RuntimeError( + "This script requires a .nemo peft model or path to hparams.yaml (and a ckpt path)." + ) + else: + model_cfg = MegatronGPTSFTModel.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + return_config=True, + ) + + if hasattr(model_cfg, 'peft') and model_cfg.peft.peft_scheme not in [None, 'none']: + # before PEFT migrates to distributed ckpt, eval must use same TP/PP as training + for p in ['tensor_model_parallel_size', 'pipeline_model_parallel_size']: + assert model_cfg.get(p) == cfg.model.get( + p + ), f"PEFT evaluation {p} ({cfg.model.get(p)}) must equal training {p} ({model_cfg.get(p)})" + + with open_dict(model_cfg): + # to be compatible with old checkpoints + if "context_key" not in model_cfg.data.train_ds or "answer_key" not in model_cfg.data.train_ds: + model_cfg.data.train_ds.context_key = "question" + model_cfg.data.train_ds.answer_key = "answer" + + # update the model config of the trained model with params we want to set at inference time. + model_cfg.precision = cfg.trainer.precision + for key, val in cfg.model.items(): + if key != 'data' and key != 'peft': + model_cfg[key] = val + model_cfg.data.test_ds = cfg.model.data.test_ds + + with open_dict(cfg): + if model_cfg.data.test_ds is not None: + cfg.inference.add_BOS = model_cfg.data.test_ds.get("add_BOS", False) + cfg.inference.tokens_to_generate = model_cfg.data.test_ds.get("tokens_to_generate", 1) + + model_cfg.megatron_amp_O2 = False # always evaluate with O1 + return model_cfg + + @classmethod + def load_adapters_for_inference(cls, cfg: DictConfig, model_cfg: DictConfig, model: ModelPT) -> ModelPT: + if cfg.model.peft.restore_from_path: + if '\\' in cfg.model.peft.restore_from_path: + cfg.model.peft.restore_from_path = cfg.model.peft.restore_from_path.replace('\\', '') + if "peft" in model_cfg: + peft_cfg_cls = PEFT_CONFIG_MAP[model_cfg.peft.peft_scheme] + model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg), map_location="cpu") + else: + model.load_state_dict(torch.load(cfg.model.peft.restore_from_path), strict=False) + elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: + checkpoint_path = os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank( + os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, + cfg.model.peft.restore_from_ckpt.checkpoint_name, + ) + ) + if "peft" in model_cfg: + peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] + model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls(model_cfg), map_location="cpu") + else: + model.load_state_dict(torch.load(checkpoint_path), strict=False) + else: + raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") + elif model_cfg.peft.get("peft_scheme", None): + # special case for loading a complete speechllm checkpoint in nemo format + peft_cfg_cls = PEFT_CONFIG_MAP[model_cfg.peft.peft_scheme] + model.load_adapters(cfg.model.restore_from_path, peft_cfg_cls(model_cfg), map_location="cpu") + return model + + def _build_vocab(self): + """ + Manipulate vocabulary (e.g., pad vocabulary for increased performance)/ + """ + if self._cfg.get('override_vocab_size', None) is not None: + self.padded_vocab_size = self._cfg.override_vocab_size + else: + self.padded_vocab_size = self._vocab_size_with_padding( + orig_vocab_size=self.tokenizer.vocab_size, + make_vocab_size_divisible_by=self._cfg.get('make_vocab_size_divisible_by', 128), + tensor_model_parallel_size=self._cfg.get('tensor_model_parallel_size', 1), + ) + + def state_dict(self, destination=None, prefix=None, keep_vars=False): + """ + Overwrite the state_dict method to include only the trainable parameters. + """ + if self.setup_complete and self.trainer.state.fn == "fit": + # Once setup is complete we only need adapter and perception model. + if self.cfg.freeze_llm and self.cfg.get("peft", None) is not None: + return_state_dict = self.get_peft_state_dict() + elif not self.cfg.freeze_llm: + return_state_dict = self.model.state_dict(prefix="model.") + else: + return_state_dict = {} + + state_dict = self.perception.state_dict(prefix="perception.") + if self.cfg.freeze_audio_encoder: + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("perception.encoder.")} + + return_state_dict.update(state_dict) + state_dict = self.perception.state_dict(prefix="perception.") + return_state_dict.update(state_dict) + return return_state_dict + elif self.setup_complete and self.trainer.state.fn != "fit": + # used to save the whole model as a nemo file + return_state_dict = self.model.state_dict(prefix="model.") + state_dict = self.perception.state_dict(prefix="perception.") + return_state_dict.update(state_dict) + return return_state_dict + else: + # we want all the params with the same keys as calling self.state_dict() + # but we can't call self.state_dict() here as it would be a recursive call. + # so we call self.model.state_dict(prefix="model.") which will return all the keys and params same as calling self.state_dict() + if not self.cfg.freeze_llm: + return_state_dict = self.model.state_dict(prefix="model.") + else: + return_state_dict = {} + state_dict = self.perception.state_dict(prefix="perception.") + if self.cfg.freeze_audio_encoder: + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("perception.encoder.")} + return_state_dict.update(state_dict) + return return_state_dict + + def load_state_dict(self, state_dict, strict: bool = True): + if not self.setup_complete: + if self.cfg.get('override_vocab_size', False): + exclude_list = [ + "model.language_model.embedding.word_embeddings.weight", + "model.language_model.output_layer.weight", + ] + else: + exclude_list = [] + state_dict = {k: v for k, v in state_dict.items() if k not in exclude_list} + else: + strict = False + + if len(state_dict) == 0: + return # checkpoint is loaded in on_load_checkpoint() + if self.use_peft and self.setup_complete: + # at this stage only adapter params will appear in the state_dict arg + # so we only update those while the rest of the model is frozen. + # setting strict=False will ignore the missing keys (which are not being updated anyway) + # explicitly check if state_dict.keys matches all the expected self.adapter_keys since we don't have the + # safety in strict=True anymore. + if not self.ptuning_only_and_non_first_stage: + if set(state_dict.keys()) != self.adapter_keys.union(self.tunable_base_param_keys): + logging.warning( + f"Unexpected keys found in state_dict: {set(state_dict.keys()) - self.adapter_keys.union(self.tunable_base_param_keys)}, missing keys in state_dict: {self.adapter_keys.union(self.tunable_base_param_keys) - set(state_dict.keys())}" + ) + super(MegatronGPTModel, self).load_state_dict(state_dict, strict=False) + else: + super(MegatronGPTModel, self).load_state_dict(state_dict, strict=strict) + + def on_load_checkpoint(self, checkpoint) -> None: + """LightningModule hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-load-checkpoint + """ + checkpoint_state_dict = checkpoint['state_dict'] + self.load_state_dict(checkpoint_state_dict, strict=False) + + def setup_metric(self, data_cfg): + metric_name = "exact_string_match" + if not hasattr(data_cfg, "metric"): + metric = MetricStringToTorchMetric["exact_string_match"] + else: + if not hasattr(data_cfg.metric, "name"): + raise ValueError("Metric name is not provided in the metric config.") + if data_cfg.metric.name == "loss": + return None, "loss" + if data_cfg.metric.name not in MetricStringToTorchMetric: + raise KeyError( + f"{data_cfg.metric.name} is not supported. List of supported metrics: {MetricStringToTorchMetric.keys()}" + ) + if data_cfg.metric.name in self._metrics_require_string2category_map: + if data_cfg.metric.average is None: + raise ValueError( + f"{data_cfg.metric.name} requires specifying whether you want to compute a micro or macro average. Found None." + ) + if ( + data_cfg.metric.get('labels_are_strings', False) + and data_cfg.metric.name in self._metrics_require_string2category_map + ): + if data_cfg.metric.num_classes is None: + raise ValueError( + "Number of classes is not provided in the metric section within the data config. " + f"Please provide the number of classes in the data config to use the {data_cfg.metric.name} metric." + ) + if data_cfg.metric.get('class_labels', None) is None or not isinstance( + data_cfg.metric.get('class_labels', None), ListConfig + ): + raise ValueError( + "Class labels are not provided properly in the metric section witnin the data config. " + f"Please provide the class labels as a list of strings in the data config to use the {data_cfg.metric.name} metric." + ) + if len(data_cfg.metric.get('class_labels', None)) != data_cfg.metric.num_classes: + raise ValueError( + f"Number of class labels {len(data_cfg.metric.get('class_labels', None))} does not match `num_classes` : {data_cfg.metric.num_classes}" + ) + + metric_name = data_cfg.metric.name + metric_cls = MetricStringToTorchMetric[metric_name] + if metric_name not in TextMetricsSet: + metric = [metric_cls(**data_cfg.metric)] + else: + metric = [metric_cls()] + return metric, metric_name + + def inference_step(self, dataloader_iter, mode): + """ + Used for validation and test steps, added postprocessing after calling self.predict_step(). + """ + batch, batch_idx, dataloader_idx = next(dataloader_iter) + data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds + self._reconfigure_and_process_inference_batch(batch, data_cfg) + # Meta data from dataset + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + loss = super(MegatronGPTSFTModel, self).validation_step(itertools.chain([batch]), dataloader_idx) + + # We need _inference_config to get generation params + # add_BOS and tokens_to_generate are set in dataset + if self.get_inference_config() is None: + logging.warning(f'inference_config is not set. Use default: {default_inference_config}') + self.set_inference_config(inference_config=default_inference_config) + self._inference_config['add_BOS'] = data_cfg.add_bos + self._inference_config['tokens_to_generate'] = data_cfg.get('tokens_to_generate') + + output = self.predict_step(batch, batch_idx, dataloader_idx) + + inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] + labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] + preds_text = [ + self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) + for t, l in zip(output['token_ids'], batch['context_lengths']) + ] + + if data_cfg.get("end_string", None): + # sometimes data_cfg.end_string != self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(data_cfg.end_string)) + # for example when data_cfg.end_string = "", the end_string_re will start with " ?? " + end_string_re = self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(data_cfg.end_string)) + preds_text_cleaned = [] + labels_text_cleaned = [] + for p, l in zip(preds_text, labels_text): + # remove end_string from the end of the string + for es in [end_string_re, data_cfg.end_string]: + if p.endswith(es): + p = p[: -len(es)].strip() + if l.endswith(es): + l = l[: -len(es)].strip() + preds_text_cleaned.append(p) + labels_text_cleaned.append(l) + preds_text = preds_text_cleaned + labels_text = labels_text_cleaned + + if data_cfg.get("remove_text_pc", False): + preds_text = [remove_punctuations(p.lower(), data_cfg.get("punctuations", None)) for p in preds_text] + labels_text = [remove_punctuations(l.lower(), data_cfg.get("punctuations", None)) for l in labels_text] + + if data_cfg.get("log_every_n_steps", None) is not None: + if batch_idx % data_cfg.log_every_n_steps == 0: + logging.info(f"Input: `{inputs_text[0]}`") + logging.info(f"Label: `{labels_text[0]}`") + logging.info(f"Pred: `{preds_text[0]}`") + + # if loss is nan, print the input, label and pred + if loss.isnan(): + logging.info("++++++++++++++ NaN loss detected ++++++++++++++") + for i in range(len(inputs_text)): + logging.info(f"Input: `{inputs_text[i]}`") + logging.info(f"Label: `{labels_text[i]}`") + logging.info(f"Pred: `{preds_text[i]}`") + logging.info("++++++++++++++++++++++++++++++++++++++++++++++++") + + outputs = { + 'loss': loss, + 'preds': preds_text, # [str] + 'labels': labels_text, # [str] + 'inputs': inputs_text, # [str] + 'metadata': metadata, # [dict] + } + + if mode == 'validation': + if len(self._validation_dl) > 1: + # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict + self.validation_step_outputs[dataloader_idx][-1] = outputs + else: + # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict + self.validation_step_outputs[-1] = outputs + else: + if len(self._test_dl) > 1: + self.test_step_outputs[dataloader_idx][-1] = outputs + else: + self.test_step_outputs[-1] = outputs + return outputs + + def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int] = None): + """ + Used to get LLM predictions for validation and test steps based on the given inference config. + """ + inference_config = self.get_inference_config() + if inference_config is not None: + # need to overwrite some configuration, make it immutable + inference_config = inference_config.copy() + else: + self.set_inference_config(inference_config=default_inference_config) + logging.warning(f'inference_config is not set. Use default: {default_inference_config}') + inference_config = self.get_inference_config() + + if self.cfg.data.get('end_string', None): + inference_config['end_strings'] = [self.cfg.data.end_string] + + global_batch_size_per_gpu = batch['tokens'].size(0) + num_micro_batches_before_decode = get_num_microbatches() + + compute_logprob = inference_config.get('compute_logprob', False) + if compute_logprob: + inference_config['inputs'] = batch + inference_config['tokens_to_generate'] = 1 + inference_config['all_probs'] = True + inference_config["add_BOS"] = False + inference_config['greedy'] = True + response = generate(self, **inference_config) + response = get_computeprob_response(self.tokenizer, response, batch) + else: + # for megatron_gpt_eval.py + if isinstance(batch, list): + inference_config['inputs'] = batch + elif 'num_audios' in batch: + # peft_eval.py + inference_config['inputs'] = ( + batch['contexts'].cuda(), + batch['context_lengths'].cuda(), + batch['audio_signal'].cuda(), + batch['audio_signal_length'].cuda(), + batch['num_audios'].cuda(), + batch['context_start_idx'], + ) + else: + # peft_eval.py + inference_config['inputs'] = ( + batch['contexts'].cuda(), + batch['context_lengths'].cuda(), + batch['audio_signal'].cuda(), + batch['audio_signal_length'].cuda(), + ) + response = generate(self, **inference_config) + + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu // num_micro_batches_before_decode, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + # add audio offsets to context lengths for properly decoding only the response + batch['context_lengths'] = batch['context_lengths'].cuda() + response['audio_feat_lens'] + + return response + + def inference_epoch_end(self, outputs, mode, data_cfg): + # Parent class will handle logging of the loss. + if not outputs or (all([not x for x in outputs])): + return None + + if isinstance(outputs[0], dict): + outputs = [outputs] + + averaged_loss = [] + averaged_metric = [] + # Log metrics for each provided validation/test dataset. + for dataloader_idx, output in enumerate(outputs): + if len(output) == 0: + logging.warning(f"Empty output for dataloader_idx: {dataloader_idx}") + continue + # Expand on_validation_epoch_end from parent class MegatronGPTModel as on_validation_epoch_end doesnt take outputs arg + loss_vals = [x['loss'] for x in output] + if parallel_state.is_pipeline_last_stage(): + # only the last pipeline parallel stages return loss with their batch size + if self.cfg.data.get('validation_drop_last', True): + loss = torch.stack(loss_vals).mean() + else: + # Compute the avg loss by total_loss across all samples / total number of samples + total_loss_and_total_samples = torch.vstack(loss_vals).sum(axis=0) + avg_loss = total_loss_and_total_samples[0] / total_loss_and_total_samples[1] + loss = avg_loss.type(torch.float32).cuda() + else: + loss = torch.tensor(0.0, dtype=torch.float32).cuda() + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(loss, get_last_rank()) + + self.log('val_loss', loss, prog_bar=True, rank_zero_only=True, batch_size=1, sync_dist=True) + + # Determine the key used to log the loss based on the user provided name of the dataset or the dataloader index. + loss_log_key = self._determine_log_key(data_cfg, dataloader_idx, "loss", mode) + self.log(loss_log_key, loss, batch_size=1) + averaged_loss.append(loss) + + # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks. + gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_outputs, + [ + {'preds': x['preds'], 'labels': x['labels'], 'inputs': x['inputs'], 'metadata': x['metadata']} + for x in output + ], + group=parallel_state.get_data_parallel_group(), + ) + + # Remove duplicate examples due to distributed sampler. + inp_label_set = set() + deduplicated_outputs = { + 'preds': [], + 'labels': [], + 'inputs': [], + 'metadata': [], + } + total_size = 0 + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_outputs[rank]: + for pred, label, input, metadata in zip( + batch['preds'], batch['labels'], batch['inputs'], batch['metadata'] + ): + key = input + label + str(metadata) + total_size += 1 + if key not in inp_label_set: + inp_label_set.add(key) + deduplicated_outputs['preds'].append(pred) + deduplicated_outputs['labels'].append(label) + deduplicated_outputs['inputs'].append(input) + deduplicated_outputs['metadata'].append(metadata) + + # Compute metric score + metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name + metric_label_key = self.val_metric_label_key if mode == 'validation' else self.test_metric_label_key + if metric_name != 'loss': + metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode) + metric_fn = self.val_metric[0] if mode == 'validation' else self.test_metric[0] + if metric_label_key in deduplicated_outputs['metadata'][0]: + labels = [m[metric_label_key] for m in deduplicated_outputs['metadata']] + else: + labels = deduplicated_outputs['labels'] + + # sacrebleu.corpus_bleu is commonly used which does not share + # the same interface as other metrics. We handle it separately. + if metric_name == 'bleu': + metric_result = torch.Tensor( + [sacrebleu.corpus_bleu(deduplicated_outputs['preds'], [labels]).score] + ).to(self.device) + else: + for pred, label in zip(deduplicated_outputs['preds'], labels): + _ = metric_fn(pred, label) + + metric_result = metric_fn.compute() + + if metric_name == 'rouge': + for k, v in metric_result.items(): + if 'fmeasure' in k: + self.log(metric_log_key + f'_{k}', v.item(), sync_dist=True, batch_size=1) + logging.info(f"{mode} {metric_name} {k}: {v.item()}") + metric_result = metric_result['rouge1_fmeasure'] + else: + self.log(metric_log_key, metric_result.item(), sync_dist=True, batch_size=1) + logging.info(f"{mode} {metric_name}: {metric_result.item()}") + + metric_fn.reset() + averaged_metric.append(metric_result) + + # Write predictions to file + if self.global_rank == 0 and data_cfg.get("write_predictions_to_file", False): + logging.info( + f"Total deduplicated inference data size: {total_size} to {len(deduplicated_outputs['inputs'])}" + ) + + # Check if the user provided a prefix path to the file(s) they want to write. + if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: + raise ValueError( + f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." + ) + filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode) + output_dir = data_cfg.get("output_dir", "./") + self.write_predictions_to_file( + deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}", output_dir + ) + + torch.distributed.barrier(group=parallel_state.get_data_parallel_group()) + outputs[dataloader_idx].clear() # free memory + + # Logging of the averaged metrics: + averaged_loss = sum(averaged_loss) / len(averaged_loss) + averaged_metric = sum(averaged_metric) / len(averaged_metric) if len(averaged_metric) > 0 else None + averaged_loss = averaged_loss.to(self.device) + if averaged_metric is not None: + averaged_metric = averaged_metric.to(self.device) + + # Handle case where metrics can be nan or inf. This can break checkpoint save/load. + if averaged_metric is not None and (torch.isinf(averaged_metric) or torch.isnan(averaged_metric)): + app_state = AppState() + monitor_mode = app_state.checkpoint_callback_params.mode + assert monitor_mode in ['min', 'max'] + averaged_metric = 0.0 if monitor_mode == 'max' else 1e5 + + if mode == 'validation': + self.log("validation_loss", averaged_loss, batch_size=1, sync_dist=True) + if averaged_metric is not None: + self.log(f"validation_{self.val_metric_name}", averaged_metric, sync_dist=True, batch_size=1) + elif mode == 'test': + self.log("test_loss", averaged_loss, batch_size=1, sync_dist=True) + if averaged_metric is not None: + self.log(f"test_{self.test_metric_name}", averaged_metric, sync_dist=True, batch_size=1) + + # Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here + app_state = AppState() + self._restore_activation_checkpointing_args() + if hasattr(self, "_train_ds"): + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.train_ds.global_batch_size, + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + # When running `trainer.validate()`, the training dataset is not available. + else: + logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.') + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=data_cfg.global_batch_size, + micro_batch_size=data_cfg.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + return averaged_loss, averaged_metric + + # consistent with speech models + @rank_zero_only + def write_predictions_to_file(self, outputs, output_file_path_prefix, output_dir): + os.makedirs(output_dir, exist_ok=True) + output_file_path = output_file_path_prefix + "_inputs_preds_labels.jsonl" + output_file_path = os.path.join(output_dir, output_file_path) + with open(output_file_path, "w") as f_json: + assert ( + len(outputs['inputs']) == len(outputs['preds']) == len(outputs['labels']) == len(outputs['metadata']) + ) + for i, p, l, m in zip(outputs['inputs'], outputs['preds'], outputs['labels'], outputs['metadata']): + json_string = {'input': i, 'pred_text': p, 'text': l} + for k, v in m.items(): + if k not in json_string: + json_string[k] = v + f_json.write(json.dumps(json_string) + '\n') + + logging.info(f'Predictions saved to {output_file_path}') + + def setup_eval_dataloader(self, datasets, data_cfg): + dataloaders = [] + if not isinstance(datasets, list): + return self.build_data_loader(dataset=datasets, data_cfg=data_cfg, consumed_samples=0) + for dataset in datasets: + eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0) + dataloaders.append(eval_dl) + return dataloaders + + def setup_predict_dataloader(self, data_cfg): + datasets = self._build_dataset(data_cfg, False) + dataloaders = [] + if not isinstance(datasets, list): + return self.build_data_loader(dataset=datasets, data_cfg=data_cfg, consumed_samples=0, is_predict=True) + for dataset in datasets: + eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0, is_predict=True) + dataloaders.append(eval_dl) + return dataloaders + + def sharded_state_dict(self, prefix: str = ''): + """ + Force None for the parent class's sharded_state_dict() method if setup is complete. + """ + if self.setup_complete: + return None + else: + return super().sharded_state_dict(prefix=prefix) + + def maybe_build_test(self): + # overwrite the parent class's maybe_build_test() method in MegatronGPTModel + if hasattr(self.cfg.data, 'test_ds'): + logging.info('Building test datasets...') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._test_ds = self._build_dataset(self.cfg.data.test_ds, is_train=False) + lengths = [len(x) for x in self._test_ds] + logging.info(f'Length of test datasets: {lengths}, total: {sum(lengths)}') + return + + def maybe_setup_test(self): + # overwrite the parent class's maybe_build_test() method in MegatronGPTModel + if hasattr(self.cfg.data, 'test_ds'): + self._test_dl = self.setup_eval_dataloader(self._test_ds, self.cfg.data.test_ds) + return + + def build_train_valid_test_datasets(self, stage): + if stage != 'test': + logging.info('Building validation datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._validation_ds = self._build_dataset(self.cfg.data.validation_ds, is_train=False) + lengths = [len(x) for x in self._validation_ds] + logging.info(f'Length of validation datasets: {lengths}, total: {sum(lengths)}') + + if stage != 'validate': + self.maybe_build_test() + + if stage == 'validate' or stage == 'test': + return + logging.info('Building training datasets.') + self._train_ds = self._build_dataset(self.cfg.data.train_ds) + logging.info(f'Length training datasets: {len(self._train_ds)}') + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + results = [] + + model = PretrainedModelInfo( + pretrained_model_name="speechllm_fc_llama2_7b", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia/nemo/speechllm_fc_llama2_7b", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/speechllm_fc_llama2_7b/versions/1.23.1/files/speechllm_fc_llama2_7b.nemo", + ) + results.append(model) + return results diff --git a/nemo/collections/multimodal/speech_llm/modules/__init__.py b/nemo/collections/multimodal/speech_llm/modules/__init__.py new file mode 100644 index 000000000000..d9562652ce84 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/modules/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.multimodal.speech_llm.modules.modality_adapters import PoolingMLPConnectors +from nemo.collections.multimodal.speech_llm.modules.perception_modules import ( + AudioPerceptionModule, + MultiAudioPerceptionModule, + MultiFeatureAggregator, +) diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py new file mode 100644 index 000000000000..0cd48502bb84 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py @@ -0,0 +1,175 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import torch + +import nemo.collections.nlp.modules.common.text_generation_strategy as text_generation_strategy +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import shift_tokens_by_multi_audios + + +# the text representation of eos_id, it applies for all tokenizers +END_OF_SEQ = '<|endoftext|>' + + +def switch(val1, val2, boolean): + boolean = boolean.type_as(val1) + boolean = boolean.unsqueeze(0).unsqueeze(-1) + return (1 - boolean) * val1 + boolean * val2 + + +class AudioToTextGenerationStrategy(text_generation_strategy.GPTModelTextGenerationStrategy): + def init_batch( + self, + context_tokens: torch.Tensor, + context_lengths: torch.Tensor, + audio_signal: torch.Tensor, + audio_length: torch.Tensor, + compute_attention_mask: bool, + num_audios: Optional[torch.Tensor] = None, + context_start_idx: Optional[List[List[int]]] = None, + ): + """initialize the batch data before the inference steps.""" + # Move to GPU. + + audio_feats, audio_feat_lens = self.model.perception( + input_signal=audio_signal, + input_signal_length=audio_length, + processed_signal=None, + processed_signal_length=None, + ) + + if num_audios is not None: + # handle multiple audio files per sample + audio_feats = audio_feats.split(num_audios.tolist()) + audio_feat_lens = audio_feat_lens.split(num_audios.tolist()) + + encoder_input, attention_mask, _, position_ids, encoder_max_length = self.model.inject_perception_input( + audio_feats, audio_feat_lens, context_tokens, context_lengths, context_start_idx + ) + + self.attention_mask = attention_mask + self.position_ids = position_ids + + if num_audios is not None: + # handle multiple audio files per sample + new_context_tokens = shift_tokens_by_multi_audios( + context_tokens, context_lengths, audio_feat_lens, context_start_idx, encoder_max_length + ) + audio_feat_lens = torch.stack([torch.sum(lens) for lens in audio_feat_lens]) # [batch,] + else: + new_context_tokens = self.model._shift_labels_by_emb_len( + context_tokens, context_lengths, audio_feat_lens, encoder_max_length, pad_token=0 + ) + + return new_context_tokens, encoder_input, audio_feat_lens + + def clip_max_len(self, maxlen: int) -> int: + """clip the max len based on the LM model max sequence length""" + # for positional embedding types that allow length extrapolation, don't clip the max length + if self.model.cfg.get("position_embedding_type", "learned_absolute") == "learned_absolute": + if maxlen > self.model.cfg.encoder_seq_length + 1: + maxlen = self.model.cfg.encoder_seq_length + 1 + return maxlen + + def prepare_batch_at_step( + self, + tokens: torch.Tensor, + input_embeddings: torch.Tensor, + maxlen: int, + micro_batch_size: int, + step: int, + context_lengths: torch.Tensor, + curr_context_length: int, + compute_attention_mask: bool, + ) -> Tuple[List[torch.Tensor], List[int]]: + # types2use = None + if step == 0: + # Allocate memory for the entire context. + set_inference_key_value_memory = True + tokens2use = tokens[:, :curr_context_length] + positions2use = self.position_ids[:, :curr_context_length] + embeddings2use = input_embeddings[:curr_context_length] + else: + # Set this to false so the memory is not reallocated. + set_inference_key_value_memory = False + tokens2use = tokens[:, curr_context_length - 1].view(micro_batch_size, -1) + positions2use = self.position_ids[:, curr_context_length - 1].view(micro_batch_size, -1) + embeddings2use = self.model._get_text_embeddings(tokens2use, positions2use) + started = context_lengths <= curr_context_length + embeddings2use = switch(input_embeddings[curr_context_length - 1].unsqueeze(0), embeddings2use, started) + + """Prepare batch for each of the inference steps""" + setkey_value_array = torch.tensor( + [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device() + ) + len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device()) + + batch = [tokens2use, embeddings2use, self.attention_mask, positions2use, setkey_value_array, len_array] + tensor_shape = [tokens2use.shape[1], micro_batch_size, self.model.cfg.hidden_size] + return batch, tensor_shape + + def post_process(self, tokens: torch.Tensor, new_tokens: torch.Tensor, context_length: int): + """ + At the end of the inference, post process the inference results + """ + pass + + def end_of_generation_condition( + self, tokens: torch.Tensor, prev: torch.Tensor, eod_id: int, end_strings: List[str] + ) -> torch.Tensor: + """ + return whether the generation should stop based on the previous token + Args: + tokens (torch.Tensor): the generated tokens so far + prev (torch.Tensor): the previous token + eod_id (int): the end of document token id + end_strings (List[str]): the list of end of generation strings + returns: + a boolean tensor indicating whether the generation should stop + """ + if len(end_strings) == 1 and end_strings[0] == END_OF_SEQ: + return prev == eod_id + else: + tokenizer = self.model.tokenizer + conditions = [] + end_tokens = set() + end_tokens.add(eod_id) + for end_string in end_strings: + if len(end_string) > 1: + continue + ids_1 = tokenizer.text_to_ids(f'{end_string}') + ids_2 = tokenizer.text_to_ids('') + if len(ids_1) <= len(ids_2): + continue + token_id = ids_1[len(ids_2) :][0] + + end_tokens.add(token_id) + + for p, token_item in zip(prev, tokens): + text = tokenizer.ids_to_text(token_item.tolist()) + conditions.append( + any([text.endswith(end_string) for end_string in end_strings] + [p.item() in end_tokens]) + ) + return torch.tensor(conditions, dtype=torch.bool, device=tokens.device) + + +def model_inference_strategy_dispatcher(model, **args): + from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel + + if isinstance(model, ModularAudioGPTModel): + return AudioToTextGenerationStrategy(model, **args) + else: + return text_generation_strategy.model_inference_strategy_dispatcher(model, **args) diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py new file mode 100644 index 000000000000..136418031586 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py @@ -0,0 +1,698 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for generating text.""" + +import pickle +from collections.abc import Iterable +from typing import List, Optional, Tuple, Union +import numpy as np +import torch +import torch.nn.functional as F + +import nemo.collections.nlp.modules.common.text_generation_utils as text_generation_utils +from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer +from nemo.collections.multimodal.speech_llm.modules.common.audio_text_generation_strategy import ( + model_inference_strategy_dispatcher, +) +from nemo.collections.nlp.modules.common.transformer.text_generation import OutputType +from nemo.utils import AppState + +try: + from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +__all__ = [ + "get_computeprob_response", + "generate", +] + + +def get_computeprob_response(tokenizer, response, inputs): + return text_generation_utils.get_computeprob_response(tokenizer, response, inputs) + + +def send_generate_info( + context_tokens_tensor, + context_length_tensor, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs, + compute_logprob, + temperature, + top_k, + top_p, + greedy, + repetition_penalty, + min_tokens_to_generate, + end_strings, + num_audios: Optional[torch.Tensor] = None, + context_start_idx: Optional[List[List[int]]] = None, +): + """ + Needs to be synced up with receive_generate_info + """ + model_parallel_group = parallel_state.get_model_parallel_group() + src = text_generation_utils.get_model_parallel_src_rank() + + audio_max_len = audio_signal.size(1) if audio_signal is not None else 0 + + # Send the sizes of the tensors + input_info = [ + context_tokens_tensor.size(0), # batch_size + context_tokens_tensor.size(1), # seq_len + audio_max_len, # audio_max_len + tokens_to_generate, + all_probs, + compute_logprob, # whether to compute log probabilities matrix + temperature, + top_k, + top_p, + greedy, + repetition_penalty, + min_tokens_to_generate, + ] + input_info_tensor = torch.cuda.FloatTensor(input_info) + torch.distributed.broadcast(input_info_tensor, src, model_parallel_group) + + # Send variables to all ranks + torch.distributed.broadcast(context_length_tensor, src, model_parallel_group) + torch.distributed.broadcast(context_tokens_tensor, src, model_parallel_group) + + torch.distributed.broadcast(audio_signal, src, model_parallel_group) + torch.distributed.broadcast(audio_signal_length, src, model_parallel_group) + + # send end strings + string_tensor = torch.as_tensor( + np.frombuffer(pickle.dumps(end_strings), dtype=np.int8), device=torch.cuda.current_device() + ) + size = torch.as_tensor([string_tensor.size(0)], device=torch.cuda.current_device(), dtype=torch.int64) + torch.distributed.broadcast(size, src, model_parallel_group) + torch.distributed.broadcast(string_tensor, src, model_parallel_group) + + if num_audios is not None: + torch.distributed.broadcast(num_audios, src, model_parallel_group) + + if context_start_idx is not None: + context_idx_tensor = torch.as_tensor( + np.frombuffer(pickle.dumps(context_start_idx), dtype=np.int8), device=torch.cuda.current_device() + ) + ctx_size = torch.as_tensor([context_idx_tensor.size(0)], device=torch.cuda.current_device(), dtype=torch.int64) + torch.distributed.broadcast(ctx_size, src, model_parallel_group) + torch.distributed.broadcast(context_idx_tensor, src, model_parallel_group) + + +def receive_generate_info(has_multi_audios=False): + """ + Needs to be synced up with send_generate_info + """ + model_parallel_group = parallel_state.get_model_parallel_group() + src = text_generation_utils.get_model_parallel_src_rank() + input_info_tensor = torch.empty(12, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(input_info_tensor, src, model_parallel_group) + batch_size = int(input_info_tensor[0].item()) + seq_len = int(input_info_tensor[1].item()) + audio_len = int(input_info_tensor[2].item()) + tokens_to_generate = int(input_info_tensor[3].item()) + all_probs = bool(input_info_tensor[4].item()) + compute_logprob = bool(input_info_tensor[5].item()) # whether to compute log probabilities matrix + temperature = float(input_info_tensor[6].item()) + top_k = int(input_info_tensor[7].item()) + top_p = float(input_info_tensor[8].item()) + greedy = bool(input_info_tensor[9].item()) + repetition_penalty = float(input_info_tensor[10].item()) + min_tokens_to_generate = int(input_info_tensor[11].item()) + + context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device()) + context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device()) + # Send variables to all ranks + torch.distributed.broadcast(context_length_tensor, src, model_parallel_group) + torch.distributed.broadcast(context_tokens_tensor, src, model_parallel_group) + + audio_signal = torch.empty(batch_size, audio_len, dtype=torch.float32, device=torch.cuda.current_device()) + audio_signal_length = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device()) + # Send variables to all ranks + torch.distributed.broadcast(audio_signal, src, model_parallel_group) + torch.distributed.broadcast(audio_signal_length, src, model_parallel_group) + + array_size = torch.empty(1, dtype=torch.int64, device=torch.cuda.current_device()) + torch.distributed.broadcast(array_size, src, model_parallel_group) + + string_tensor = torch.empty(array_size[0], dtype=torch.int8, device=torch.cuda.current_device()) + torch.distributed.broadcast(string_tensor, src, model_parallel_group) + bytes = string_tensor.cpu().numpy().tobytes() + end_strings = pickle.loads(bytes) + + num_audios = None + context_start_idx = None + if has_multi_audios: + num_audios = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device()) + torch.distributed.broadcast(num_audios, src, model_parallel_group) + + array_size = torch.empty(1, dtype=torch.int64, device=torch.cuda.current_device()) + torch.distributed.broadcast(array_size, src, model_parallel_group) + context_idx_tensor = torch.empty(array_size[0], dtype=torch.int8, device=torch.cuda.current_device()) + torch.distributed.broadcast(context_idx_tensor, src, model_parallel_group) + bytes = context_idx_tensor.cpu().numpy().tobytes() + context_start_idx = pickle.loads(bytes) + + return ( + context_length_tensor, + context_tokens_tensor, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs, + compute_logprob, + temperature, + top_k, + top_p, + greedy, + repetition_penalty, + min_tokens_to_generate, + end_strings, + num_audios, + context_start_idx, + ) + + +def synced_generate( + model, + inference_strategy, + context_tokens_tensor, + context_length_tensor, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs, + temperature, + top_k=0, + top_p=0.0, + greedy=False, + compute_attention_mask=True, + compute_logprob=False, + repetition_penalty=1.2, + end_strings=[], + min_tokens_to_generate=0, + num_audios: Optional[torch.Tensor] = None, + context_start_idx: Optional[List[List[int]]] = None, +): + context_length = context_length_tensor.min().item() + tokenizer = model.tokenizer + if isinstance(tokenizer, TabularTokenizer): + raise NotImplementedError("Tabular generation is not supported yet") + else: + batch_token_iterator = sample_sequence_batch( + model, + inference_strategy, + context_tokens_tensor, + context_length_tensor, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs, + compute_attention_mask=compute_attention_mask, + compute_logprob=compute_logprob, + temperature=temperature, + end_strings=end_strings, + extra={ + "top_p": top_p, + "top_k": top_k, + "greedy": greedy, + "repetition_penalty": repetition_penalty, + "min_tokens_to_generate": min_tokens_to_generate, + }, + num_audios=num_audios, + context_start_idx=context_start_idx, + ) + + for tokens, lengths, output_logits, full_logits, audio_feat_lens in batch_token_iterator: + context_length += 1 + context_length += audio_feat_lens.min().item() + if parallel_state.is_pipeline_last_stage(): + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_embedding_group() + if compute_logprob: + torch.distributed.broadcast(output_logits, src, group) + if all_probs: + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_embedding_group() + torch.distributed.broadcast(full_logits, src, group) + + else: + if parallel_state.is_pipeline_first_stage(): + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_embedding_group() + + if compute_logprob: + precision = model._trainer.precision + if precision in [16, "16"]: + dtype = torch.float16 + elif precision == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + output_logits = torch.empty( + tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda") + ) + torch.distributed.broadcast(output_logits, src, group) + + if all_probs: + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_embedding_group() + full_logits = torch.empty( + tokens.size(0), + context_length - 1, + model.padded_vocab_size, + dtype=dtype, + device=torch.device("cuda"), + ) + torch.distributed.broadcast(full_logits, src, group) + if tokens is not None: + return tokens[:, :context_length], output_logits, full_logits, audio_feat_lens + return None + + +def generate( + model, + inputs: Union[Tuple, List[str]], + tokens_to_generate=0, + all_probs=False, + temperature=1.0, + add_BOS=False, + top_k=0, + top_p=0.0, + greedy=False, + compute_attention_mask=True, + compute_logprob=False, + repetition_penalty=1.0, + end_strings=['<|endoftext|>'], + min_tokens_to_generate=0, + **strategy_args, +) -> OutputType: + """ + Args: + model (NLPModel): text generative model + inputs (Union[tuple, List[str]]): if it is a tuple, it is assumed to be (context_tokens_tensor, context_length_tensor). Otherwise it it a list of prompt text strings + tokens_to_generate (int): The maximum length of the tokens to be generated. + all_probs (bool): Return the log prob for all the tokens + temperature (float): sampling temperature + add_BOS (bool): add the bos token at the begining of the prompt + top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (float): If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + greedy (bool): Whether or not to use sampling ; use greedy decoding otherwise + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty + min_tokens_to_generate (int): The minimum length of the tokens to be generated + strategy_args, the extra arguments are treated as inference strategy arguments + end_strings, a list of strings to stop generation when they are encountered in the output. + Returns: + OutputType: It generates the output in a dictionary type. It has the following keys: + sentences: List[str], output sentences + tokens: List[List[str]], output sentences borken into tokens + logprob: List[Tensor], log prob of generated tokens + full_logprob: List[Tensor], log prob of all the tokens in the vocab + token_ids: List[Tensor], output sentence token ids + offsets: List[List[int]] # list of tokens start positions in text + """ + if 'strategy' in strategy_args: + inference_strategy = strategy_args['strategy'] + else: + inference_strategy = model_inference_strategy_dispatcher(model) + tokenizer = model.tokenizer + has_multi_audios = False + num_audios = None + context_start_idx = None + audio_signal, audio_signal_length = None, None + if torch.distributed.get_rank() == text_generation_utils.get_model_parallel_src_rank(): + if isinstance(inputs, tuple) and len(inputs) == 2: + context_tokens_tensor, context_length_tensor = inputs + elif isinstance(inputs, tuple) and len(inputs) == 4: + context_tokens_tensor, context_length_tensor, audio_signal, audio_signal_length = inputs + elif isinstance(inputs, tuple) and len(inputs) == 6: # multi-audio + has_multi_audios = True + ( + context_tokens_tensor, + context_length_tensor, + audio_signal, + audio_signal_length, + num_audios, + context_start_idx, + ) = inputs + else: + context_tokens_tensor, context_length_tensor = inference_strategy.tokenize_batch( + inputs, tokens_to_generate, add_BOS + ) + + send_generate_info( + context_tokens_tensor, + context_length_tensor, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs, + compute_logprob, + temperature, + top_k, + top_p, + greedy, + repetition_penalty, + min_tokens_to_generate, + end_strings, + num_audios, + context_start_idx, + ) + else: + ( + context_length_tensor, + context_tokens_tensor, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs, + compute_logprob, + temperature, + top_k, + top_p, + greedy, + repetition_penalty, + min_tokens_to_generate, + end_strings, + num_audios, + context_start_idx, + ) = receive_generate_info(has_multi_audios) + + output = synced_generate( + model, + inference_strategy, + context_tokens_tensor, + context_length_tensor, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs, + temperature, + compute_attention_mask=compute_attention_mask, + compute_logprob=compute_logprob, + top_k=top_k, + top_p=top_p, + greedy=greedy, + repetition_penalty=repetition_penalty, + end_strings=end_strings, + min_tokens_to_generate=min_tokens_to_generate, + num_audios=num_audios, + context_start_idx=context_start_idx, + ) + special_tokens = set() + if hasattr(tokenizer, 'pad_token') and tokenizer.pad_token is not None: + special_tokens.add(tokenizer.pad_token) + if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token is not None: + special_tokens.add(tokenizer.eos_token) + if hasattr(tokenizer, 'bos_token') and tokenizer.bos_token is not None: + special_tokens.add(tokenizer.bos_token) + if hasattr(tokenizer, 'cls_token') and tokenizer.cls_token is not None: + special_tokens.add(tokenizer.cls_token) + if hasattr(tokenizer, 'unk_token') and tokenizer.unk_token is not None: + special_tokens.add(tokenizer.unk_token) + if hasattr(tokenizer, 'sep_token') and tokenizer.sep_token is not None: + special_tokens.add(tokenizer.sep_token) + if hasattr(tokenizer, 'mask_token') and tokenizer.mask_token is not None: + special_tokens.add(tokenizer.mask_token) + if output is not None: + decode_tokens, output_logits, full_logits, audio_feat_lens = output + resp_sentences = [] + resp_sentences_seg = [] + + decode_tokens = decode_tokens.cpu().numpy().tolist() + for decode_token in decode_tokens: + sentence = tokenizer.ids_to_text(decode_token) + resp_sentences.append(sentence) + if not isinstance(tokenizer, TabularTokenizer): + words = [] + for token in decode_token: + if not isinstance(token, Iterable): + token = [token] + word = tokenizer.ids_to_tokens(token) + if isinstance(word, Iterable): + word = word[0] + if hasattr(tokenizer.tokenizer, 'byte_decoder'): + word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode( + 'utf-8', errors='replace' + ) + words.append(word) + resp_sentences_seg.append(words) + else: + words = tokenizer.text_to_tokens(sentence) + resp_sentences_seg.append(words) + + # offsets calculation + all_offsets = [] + for item in resp_sentences_seg: + offsets = [0] + for index, token in enumerate(item): + if index != len(item) - 1: + if token in special_tokens: + offsets.append(offsets[-1]) + else: + offsets.append(len(token) + offsets[-1]) + all_offsets.append(offsets) + + output = {} + output['sentences'] = resp_sentences + output['tokens'] = resp_sentences_seg + output['logprob'] = output_logits + output['full_logprob'] = full_logits + output['token_ids'] = decode_tokens + output['offsets'] = all_offsets + output['audio_feat_lens'] = audio_feat_lens + output = inference_strategy.post_generation_process(output) + return output + return None + + +def switch(val1, val2, boolean): + boolean = boolean.type_as(val1) + return (1 - boolean) * val1 + boolean * val2 + + +def sample_sequence_batch( + model, + inference_strategy, + context_tokens, + context_lengths, + audio_signal, + audio_signal_length, + tokens_to_generate, + all_probs=False, + compute_attention_mask=True, + compute_logprob=False, + type_ids=None, + temperature=None, + end_strings=['<|endoftext|>'], + extra={}, + num_audios: Optional[torch.Tensor] = None, + context_start_idx: Optional[List[List[int]]] = None, +): + app_state = AppState() + micro_batch_size = context_tokens.shape[0] + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=micro_batch_size, + micro_batch_size=micro_batch_size, + data_parallel_size=1, + ) + assert tokens_to_generate > 0, "tokens_to_generate should be > 0" + assert ( + model.cfg.get('sequence_parallel', False) == False + ), 'sequence_parallel should be False during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint' + assert ( + model.cfg.get('activations_checkpoint_granularity', None) is None + ), 'activations_checkpoint_granularity should be None during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint' + assert ( + model.cfg.get('activations_checkpoint_method', None) is None + ), 'activations_checkpoint_method should be None during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint' + + tokenizer = model.tokenizer + # initialize the batch + with torch.no_grad(): + context_tokens, input_embeddings, audio_feat_lens = inference_strategy.init_batch( + context_tokens, + context_lengths, + audio_signal, + audio_signal_length, + compute_attention_mask, + num_audios, + context_start_idx, + ) + audio_text_context_lengths = context_lengths + audio_feat_lens + context_length = audio_text_context_lengths.min().item() + # added eos_id to support the function generate_samples_eval that passes + # eos_id as an argument and needs termination when that id id found. + eod_id = tokenizer.eos_id + counter = 0 + batch_size = context_tokens.size(0) + is_done = torch.zeros([batch_size]).byte().cuda() + tokens = context_tokens + output_logits = None + all_generated_indices = None # used to track all generated indices + # Generate enough tokens for the longest sequence + maxlen = tokens_to_generate + audio_text_context_lengths.max().item() + maxlen = inference_strategy.clip_max_len(maxlen) + lengths = torch.ones([batch_size]).long().cuda() * maxlen + while context_length < maxlen: + batch, tensor_shape = inference_strategy.prepare_batch_at_step( + tokens, + input_embeddings, + maxlen, + micro_batch_size, + counter, + audio_text_context_lengths, + context_length, + compute_attention_mask, + ) + output = inference_strategy.forward_step(batch, tensor_shape) + if parallel_state.is_pipeline_last_stage(): + if compute_logprob: + output = output[0]['logits'] + output = tensor_parallel.gather_from_tensor_model_parallel_region(output) + assert output is not None + logits = output[:, -1].view(batch_size, -1).contiguous() + + else: + logits = output[0]['logits'][:, -1].contiguous() + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + assert logits is not None + logits = logits.view(batch_size, -1) + + # make sure it will generate at least min_length + min_length = extra.get('min_tokens_to_generate', 0) + if min_length > 0: + within_min_length = (context_length - audio_text_context_lengths) < min_length + logits[within_min_length, eod_id] = -float('Inf') + # make sure it won't sample outside the vocab_size range + logits[:, tokenizer.vocab_size :] = -float('Inf') + + # started indicates whether the current token step passes the context_length, so we make sure not to overwrite the context tokens + started = audio_text_context_lengths <= context_length + if extra.get('greedy', False): + prev = torch.argmax(logits, dim=-1).view(-1) + else: + logits = logits.float() + logits /= temperature + # handle repetition penality + logits = text_generation_utils.repetition_penalty( + logits, extra.get('repetition_penalty', 1.2), all_generated_indices + ) + logits = text_generation_utils.top_k_logits( + logits, top_k=extra.get('top_k', 0), top_p=extra.get('top_p', 0.9), started=started + ) + probs = F.softmax(logits, dim=-1) + # TODO(zhehuai) + probs = probs.nan_to_num(1.0) + prev = torch.multinomial(probs, num_samples=1).view(-1) + + # Clamp the predicted out of vocabulary tokens + prev = torch.clamp(prev, max=tokenizer.vocab_size - 1) + new_tokens = switch(tokens[:, context_length].view(-1), prev, started) + + # Replace sampled tokens w/ done token if EOD has already been sampled + new_tokens = switch(new_tokens, eod_id, is_done) + + # post process the inference tokens based on the strategy + inference_strategy.post_process(tokens, new_tokens, context_length) + + # Insert either new predicted or next prompt token + tokens[:, context_length] = new_tokens + + if compute_logprob: + if output_logits is None: + output = F.log_softmax(output[:, :context_length, :], 2) + + indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2) + output_logits = torch.gather(output, 2, indices).squeeze(2) + all_generated_indices = indices[:, :, 0] + if all_probs: + full_logits = output + else: + output = F.log_softmax(output, 2) + indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2) + new_output_logits = torch.gather(output, 2, indices).squeeze(2) + + # TODO(rprenger) we're copying output_logits every time. Should pre-allocate + output_logits = torch.cat([output_logits, new_output_logits], 1) + all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1) + if all_probs: + full_logits = torch.cat([full_logits, output], 1) + + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_embedding_group() + torch.distributed.broadcast(new_tokens, src, group) + + # done_token = (prev == eod_id).byte() & started.byte() + done_token = inference_strategy.end_of_generation_condition( + tokens[:, : context_length + 1], prev, eod_id, end_strings + ) + done_token = done_token.byte() & started.byte() + + just_finished = (done_token & ~is_done).bool() + lengths[just_finished.view(-1)] = context_length + is_done = is_done | done_token + + done = torch.all(is_done) + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_pipeline_model_parallel_group() + torch.distributed.broadcast(done, src, group) + if compute_logprob: + if all_probs: + yield tokens, lengths, output_logits, full_logits, audio_feat_lens + else: + yield tokens, lengths, output_logits, None, audio_feat_lens + else: + yield tokens, lengths, None, None, audio_feat_lens + + else: + if parallel_state.is_pipeline_first_stage(): + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_embedding_group() + new_tokens = torch.empty_like(tokens[:, context_length]) + torch.distributed.broadcast(new_tokens, src, group) + tokens[:, context_length] = new_tokens + yield tokens, None, None, None, audio_feat_lens + else: + yield None, None, None, None, audio_feat_lens + + done = torch.cuda.ByteTensor([0]) + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_pipeline_model_parallel_group() + torch.distributed.broadcast(done, src, group) + + context_length += 1 + counter += 1 + if done: + break diff --git a/nemo/collections/multimodal/speech_llm/modules/modality_adapters.py b/nemo/collections/multimodal/speech_llm/modules/modality_adapters.py new file mode 100644 index 000000000000..408231adcc6d --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/modules/modality_adapters.py @@ -0,0 +1,134 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +import torch +import torch.nn as nn + +from nemo.collections.common.parts.multi_layer_perceptron import MultiLayerPerceptron as MLP +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.mixins import AccessMixin +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType + +__all__ = ['PoolingMLPConnectors'] + + +class ConcatPooling(nn.Module): + """ + A module that perform pooling by concatenating the features of every pooling_factor frames. + """ + + def __init__(self, pooling_factor): + super().__init__() + self.pooling_factor = pooling_factor + + def forward(self, x): + # x: [batch_size, seq_len, input_dim] + batch_size, seq_len, input_dim = x.shape + if seq_len % self.pooling_factor != 0: + x = x[:, : -(seq_len % self.pooling_factor), :] + x = x.reshape(batch_size, seq_len // self.pooling_factor, input_dim * self.pooling_factor) + return x + + +class PoolingMLPConnectors(NeuralModule, Exportable, AccessMixin): + """ + A module that performs pooling and MLP on the input features. + Currently only supports mean pooling and concatenation pooling. + """ + + def __init__( + self, + input_dim, + hidden_dim, + output_dim=None, + num_layers: int = 2, + activation: str = "relu", + pooling: str = "mean", + pooling_factor: int = 2, + **kwargs, # keep this to avoid breaking existing code + ): + """ + Args: + input_dim: input dimension of the features + hidden_dim: hidden dimension of the MLP layers + output_dim: output dimension of the features + num_layers: number of layers in the MLP + activation: activation function used in MLP + pooling: type of pooling, currently only supports "mean" and "cat" + pooling_factor: size of the pooling window + """ + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim if output_dim else input_dim + self.num_layers = num_layers + self.activation = activation + self.pooling = pooling + self.pooling_factor = pooling_factor + + if num_layers == 1: + self.hidden_dim = output_dim + + if pooling == "cat": + self.preprocess = nn.Sequential( + ConcatPooling(pooling_factor), nn.Linear(input_dim * pooling_factor, self.hidden_dim) + ) + else: + self.preprocess = nn.Sequential( + nn.AvgPool1d(pooling_factor, stride=pooling_factor), nn.Linear(input_dim, self.hidden_dim) + ) + + if num_layers == 1: + self.mlp = nn.Identity() + else: + self.mlp = MLP(self.hidden_dim, output_dim, num_layers, activation, log_softmax=False) + + @property + def input_types(self): + """Returns definitions of module input ports.""" + return OrderedDict( + { + "audio_signal": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()), + "length": NeuralType(tuple("B"), LengthsType()), + } + ) + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return OrderedDict( + { + "outputs": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()), + "outputs_len": NeuralType(tuple("B"), LengthsType()), + } + ) + + @typecheck() + def forward(self, audio_signal, length=None): + """ + Args: + audio_signal: [batch_size, input_dim, seq_len] + length: [batch_size] + Returns: + outputs: [batch_size, output_dim, seq_len//pooling_factor] + outputs_len: [batch_size] + """ + outputs = self.preprocess(audio_signal.transpose(1, 2)) + outputs = self.mlp(outputs) + outputs_len = torch.div(length, self.pooling_factor, rounding_mode='floor') + return outputs.transpose(1, 2), outputs_len diff --git a/nemo/collections/multimodal/speech_llm/modules/perception_modules.py b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py new file mode 100644 index 000000000000..2f0565982941 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py @@ -0,0 +1,431 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from typing import List, Optional, Tuple + +import torch +import torch.distributed +import torch.nn as nn +from omegaconf import DictConfig + +from nemo.collections.asr.models import EncDecSpeakerLabelModel +from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder, ConformerMultiLayerFeatureExtractor +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import align_feat_seq_list +from nemo.core.classes import Exportable, NeuralModule +from nemo.core.classes.common import typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType +from nemo.utils.decorators import experimental + + +__all__ = ["AudioPerceptionModule", "MultiAudioPerceptionModule"] + + +class AudioPerceptionModule(NeuralModule, Exportable): + """Audio perception module that consists of audio encoder(s) and modality adapter.""" + + def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200): + batch_size = torch.randint(low=1, high=max_batch, size=[1]).item() + max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item() + signals = torch.rand(size=[batch_size, max_length]) * 2 - 1 + lengths = torch.randint(low=min_length, high=max_dim, size=[batch_size]) + lengths[0] = max_length + return signals, lengths, None, None + + @property + def input_types(self): + """Returns definitions of module input ports.""" + return OrderedDict( + { + "input_signal": NeuralType(("B", "T"), AudioSignal(freq=self.preprocessor._sample_rate)), + "input_signal_length": NeuralType( + tuple("B"), LengthsType() + ), # Please note that length should be in samples not seconds. + "processed_signal": NeuralType(("B", "D", "T"), SpectrogramType()), + "processed_signal_length": NeuralType(tuple("B"), LengthsType()), + } + ) + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return OrderedDict( + { + "encoded": NeuralType(("B", "T", "D"), AcousticEncodedRepresentation()), + "encoded_len": NeuralType(tuple("B"), LengthsType()), + } + ) + + def __init__(self, cfg: DictConfig): + super().__init__() + # Initialize components + self.preprocessor = self.from_config_dict(cfg.preprocessor) + self.encoder = self.from_config_dict(cfg.encoder) + + if cfg.get("use_multi_layer_feat", False) and cfg.get("multi_layer_feat", None): + if "_target_" in cfg.multi_layer_feat.aggregator: + aggregator = self.from_config_dict(cfg.multi_layer_feat.aggregator) + else: + aggregator = MultiFeatureAggregator(cfg.multi_layer_feat.aggregator, channel_dim=1) + self.encoder = ConformerMultiLayerFeatureExtractor( + encoder=self.encoder, layer_idx_list=cfg.multi_layer_feat.layer_idx_list, aggregator=aggregator + ) + + if 'spec_augment' in cfg and cfg.spec_augment is not None: + self.spec_augmentation = self.from_config_dict(cfg.spec_augment) + else: + self.spec_augmentation = None + self.modality_adapter = self.from_config_dict(cfg.modality_adapter) + if 'output_dim' not in cfg.modality_adapter and "d_model" in cfg.modality_adapter: # e.g., conformer encoder + self.proj = nn.Linear(cfg.modality_adapter.d_model, cfg.output_dim) + else: + self.proj = nn.Identity() + + def maybe_preprocess_audio( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) is False: + raise ValueError( + f"{self.__class__} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, + length=input_signal_length, + ) + return processed_signal, processed_signal_length + + # disable type checks to avoid type-check errors when using Conformer as modality adapter + @typecheck.disable_checks() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + ): + processed_signal, processed_signal_length = self.maybe_preprocess_audio( + input_signal, input_signal_length, processed_signal, processed_signal_length + ) + + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + encoded, encoded_len = self.modality_adapter(audio_signal=encoded, length=encoded_len) + + # b, c, t -> b, t, c + encoded = self.proj(encoded.transpose(1, 2)) + + return encoded, encoded_len + + +class MultiFeatureAggregator(nn.Module): + """ + A module used to aggregate multiple encoded features (from different encoders or different layers) into a single feature sequence. + """ + + def __init__(self, cfg: DictConfig, channel_dim: int = 1): + super().__init__() + self.mode = cfg.get("mode", "cat") + self.channel_dim = channel_dim + self.pooling = cfg.get("pooling", "mean") + self.align_mode = cfg.get("align_mode", "min") + + def _have_same_length(self, encoded_len: List[torch.Tensor]) -> bool: + sample_len = encoded_len[0] + for x in encoded_len: + if torch.sum(x - sample_len) != 0: + return False + return True + + def forward( + self, + encoded: List[torch.Tensor], + encoded_len: List[torch.Tensor], + ref_idx: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if not self._have_same_length(encoded_len): + """Align the length of encoded features if they are different.""" + target_len = encoded[0].size(self.channel_dim) + if ref_idx is not None: + target_len = encoded[ref_idx].size(self.channel_dim) + if self.channel_dim != 1: + encoded = [x.transpose(1, self.channel_dim) for x in encoded] + encoded, encoded_len = align_feat_seq_list( + encoded, encoded_len, mode=self.align_mode, pooling=self.pooling, target_len=target_len + ) + if self.channel_dim != 1: + encoded = [x.transpose(1, self.channel_dim) for x in encoded] + + if self.mode == "cat": + return torch.cat(encoded, dim=self.channel_dim), encoded_len[0] + elif self.mode == "sum": + return torch([x.unsqueeze(-1) for x in encoded], dim=-1).sum(dim=-1), encoded_len[0] + elif self.mode == "mean" or self.mode == "avg": + return torch([x.unsqueeze(-1) for x in encoded], dim=-1).mean(dim=-1), encoded_len[0] + elif self.mode == "max": + return torch([x.unsqueeze(-1) for x in encoded], dim=-1).max(dim=-1), encoded_len[0] + elif self.mode == "min": + return torch([x.unsqueeze(-1) for x in encoded], dim=-1).min(dim=-1), encoded_len[0] + elif self.mode == "none": + return encoded, encoded_len + else: + raise ValueError(f"Unknown mode {self.mode}") + + +@experimental +class MultiAudioPerceptionModule(NeuralModule, Exportable): + """ + Audio perception module that consists of multiple audio encoders and shared modality adapter. + This module is experimental. An example perception cfg is: + ------------------- + perception: + modality_adapter: + _target_: nemo.collections.multimodal.speechllm.modules.PoolingMLPConnectors + hidden_dim: 512 + pooling: 'cat' + pooling_factor: 2 + num_layers: 4 + input_dim: -1 + output_dim: -1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoders: + asr_model: + _target_: nemo.collections.asr.models.ASRModel + output_key: d_model + freeze: True + pretrained_model: stt_en_fastconformer_transducer_large + ssl_model: + _target_: nemo.collections.asr.models.SpeechEncDecSelfSupervisedModel + output_key: d_model + freeze: True + pretrained_model: ssl_en_conformer_large + use_multi_layer_feat: True + multi_layer_feat: + layer_idx_list: [0,16] + aggregator: + mode: "cat" + pooling: "avg" + rounding: "floor" + + speaker_model: + segment_length_in_secs: 0.4 + freeze: True + pretrained_model: titanet_large + + ref_model: asr_model + aggregator: + mode: "cat" + pooling: "mean" + rounding: "floor" + ------------------- + """ + + def __init__(self, cfg: DictConfig): + super().__init__() + # Initialize components + self.aggregator = MultiFeatureAggregator(cfg.aggregator, channel_dim=1) + if 'spec_augment' in cfg and cfg.spec_augment is not None: + self.spec_augmentation = self.from_config_dict(cfg.spec_augment) + else: + self.spec_augmentation = None + + self.encoder_cfg = cfg.encoders + if not isinstance(self.encoder_cfg, DictConfig): + raise TypeError(f"cfg.encoders must be a DictConfig, got {type(cfg.encoders)}") + + preprocessor = {} + encoders = {} + for key, enc_cfg in self.encoder_cfg.items(): + encoder = self.from_config_dict(enc_cfg.model) + if enc_cfg.get("use_multi_layer_feat", False) and enc_cfg.get("multi_layer_feat", None): + if not isinstance(encoder, ConformerEncoder): + raise TypeError( + f"Encoder {key} must be a ConformerEncoder when use_multi_layer_feat is True, got {type(encoder)}" + ) + if "_target_" in enc_cfg.multi_layer_feat.aggregator: + aggregator = self.from_config_dict(enc_cfg.multi_layer_feat.aggregator) + else: + aggregator = MultiFeatureAggregator(enc_cfg.multi_layer_feat.aggregator, channel_dim=1) + encoder = ConformerMultiLayerFeatureExtractor( + encoder=encoder, layer_idx_list=enc_cfg.multi_layer_feat.layer_idx_list, aggregator=aggregator + ) + encoders[key] = encoder + preprocessor[key] = ( + self.from_config_dict(enc_cfg.get("preprocessor")) + if enc_cfg.get("preprocessor", None) is not None + else None + ) + self.encoders = nn.ModuleDict(encoders) + self.preprocessor = nn.ModuleDict(preprocessor) + + self.speaker_model = None + self.speaker_seg_len = None + if "speaker_model" in cfg and cfg.speaker_model.get("model", None) is not None: + self.speaker_model = EncDecSpeakerLabelModel(cfg=cfg.speaker_model.model) + self.speaker_model.spec_augmentation = self.spec_augmentation + self.speaker_seg_len = 1 + if "preprocessor" in cfg.speaker_model.model: + self.speaker_seg_len = int( + cfg.speaker_model.segment_length_in_secs // cfg.speaker_model.model.preprocessor.window_stride + ) + self.ref_model = cfg.get("ref_model", None) + if self.ref_model is not None: + if self.ref_model not in self.encoders and ( + self.ref_model != "speaker_model" and self.speaker_model is not None + ): + if self.ref_model == "speaker_model": + raise ValueError(f"ref_model is `{self.ref_model}` but speaker_model is None") + raise ValueError(f"ref_model `{self.ref_model}` not found in encoders [{encoders.keys()}]") + + self.modality_adapter = self.from_config_dict(cfg.modality_adapter) + if 'output_dim' not in cfg.modality_adapter and "d_model" in cfg.modality_adapter: # e.g., conformer encoder + self.proj = nn.Linear(cfg.modality_adapter.d_model, cfg.output_dim) + else: + self.proj = nn.Identity() + + def maybe_preprocess_audio( + self, + preprocessor, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) is False: + raise ValueError( + f"{self.__class__} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + + if not has_processed_signal and preprocessor is not None: + processed_signal, processed_signal_length = preprocessor( + input_signal=input_signal, + length=input_signal_length, + ) + elif not has_processed_signal and preprocessor is None: + processed_signal, processed_signal_length = input_signal, input_signal_length + return processed_signal, processed_signal_length + + def forward_speaker( + self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) is False: + raise ValueError( + f"{self.__class__} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + if not has_processed_signal: + processed_signal, processed_signal_length = self.speaker_model.preprocessor( + input_signal=input_signal, + length=input_signal_length, + ) + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + # encoded has shape [B, D, T], length has shape [B] + encoded, encoded_len = self.speaker_model.encoder( + audio_signal=processed_signal, length=processed_signal_length + ) + + # pad encoded to be divisible by speaker_seg_len + if encoded.shape[2] % self.speaker_seg_len != 0: + encoded = torch.cat( + [ + encoded, + torch.zeros( + encoded.shape[0], + encoded.shape[1], + self.speaker_seg_len - encoded.shape[2] % self.speaker_seg_len, + device=encoded.device, + ), + ], + dim=2, + ) + + B, D, T = encoded.shape + num_seg = int(T // self.speaker_seg_len) + encoded = encoded.view(int(B * num_seg), D, self.speaker_seg_len) # [B*num_seg, D, seg_len] + encoded_len_seg = (encoded_len // self.speaker_seg_len).repeat_interleave(num_seg) # [B*seg_len] + + _, embeds = self.speaker_model.decoder(encoder_output=encoded, length=encoded_len_seg) + + embeds = embeds.view(B, -1, num_seg) # [B, D, num_seg] + + embeds_len = encoded_len // self.speaker_seg_len # [B] + return embeds, embeds_len + + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + ): + encoded_list = [] + encoded_len_list = [] + ref_idx = None + for key, encoder in self.encoders.items(): + curr_processed_signal, curr_processed_signal_length = self.maybe_preprocess_audio( + self.preprocessor[key], input_signal, input_signal_length, processed_signal, processed_signal_length + ) + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation( + input_spec=curr_processed_signal, length=curr_processed_signal_length + ) + encoded, encoded_len = encoder(audio_signal=curr_processed_signal, length=curr_processed_signal_length) + if key == self.ref_model: + ref_idx = len(encoded_list) + encoded_list.append(encoded) + encoded_len_list.append(encoded_len) + + if self.speaker_model is not None: + speaker_embeds, speaker_embeds_len = self.forward_speaker( + input_signal=input_signal, + input_signal_length=input_signal_length, + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + ) + encoded_list.append(speaker_embeds) + encoded_len_list.append(speaker_embeds_len) + encoded_list, encoded_len_list = self.aggregator( + encoded=encoded_list, encoded_len=encoded_len_list, ref_idx=ref_idx + ) + encoded, encoded_len = self.modality_adapter(audio_signal=encoded_list, length=encoded_len_list) + # b, c, t -> b, t, c + encoded = self.proj(encoded.transpose(1, 2)) + return encoded, encoded_len diff --git a/nemo/collections/multimodal/speech_llm/parts/__init__.py b/nemo/collections/multimodal/speech_llm/parts/__init__.py new file mode 100644 index 000000000000..d0c4b8bd282c --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/parts/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import ( + ceil_to_nearest, + get_num_samples_from_files, + maybe_cast_to_list, + shift_tokens_by_multi_audios, +) diff --git a/nemo/collections/multimodal/speech_llm/parts/mixins/__init__.py b/nemo/collections/multimodal/speech_llm/parts/mixins/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/parts/mixins/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/speech_llm/parts/mixins/adapter_mixin.py b/nemo/collections/multimodal/speech_llm/parts/mixins/adapter_mixin.py new file mode 100644 index 000000000000..6071bda87057 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/parts/mixins/adapter_mixin.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin, replace_prefix +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP, PEFTConfig +from nemo.utils import logging + + +class SpeechLLMAdapterMixin(NLPAdapterModelMixin): + def load_adapters( + self, + filepath: str, + peft_cfgs: Optional[Union[PEFTConfig, List[PEFTConfig]]] = None, + map_location: str = None, + ): + """ + Utility method that restores only the adapter module(s), and not the entire model itself. + This allows the sharing of adapters which are often just a fraction of the size of the full model, + enabling easier delivery. + + .. note:: + + During restoration, assumes that the model does not currently already have one or more adapter modules. + + Args: + filepath: Filepath of the .ckpt or .nemo file. + peft_cfgs: One or more PEFTConfig objects that specify the PEFT method configuration. + If none, will infer from the .nemo checkpoint + map_location: Pytorch flag, where to place the adapter(s) state dict(s). + """ + + # Determine device + if map_location is None: + if torch.cuda.is_available(): + map_location = 'cuda' + else: + map_location = 'cpu' + + if filepath.endswith('.nemo'): + conf, state_dict = self._get_config_and_state_dict_from_nemo(filepath, map_location) + elif filepath.endswith('.ckpt'): + state_dict = torch.load(filepath, map_location)['state_dict'] + else: + raise RuntimeError(f"{filepath} is not nemo file or ckpt file") + if not peft_cfgs: + assert filepath.endswith( + '.nemo' + ), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument." + peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)] + if self.cfg.megatron_amp_O2: + state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()} + self.add_adapter(peft_cfgs) + if not self.ptuning_only_and_non_first_stage: + target_keys = self.adapter_keys.union(self.tunable_base_param_keys) + if set(state_dict.keys()) != target_keys: + logging.warning( + f"Unexpected keys found in state_dict: {set(state_dict.keys()) - target_keys}, missing keys in state_dict: {target_keys - set(state_dict.keys())}" + ) + super(MegatronGPTModel, self).load_state_dict(state_dict, strict=False) diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/__init__.py b/nemo/collections/multimodal/speech_llm/parts/utils/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/parts/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py new file mode 100644 index 000000000000..92a3548f9337 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py @@ -0,0 +1,157 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import numpy as np +import torch + + +def maybe_cast_to_list(x): + if isinstance(x, np.ndarray): + return [item.tolist() for item in x] + return x + + +def ceil_to_nearest(n, m): + return (n + m - 1) // m * m + + +def get_num_samples_from_files(file_list): + if isinstance(file_list, str): + file_list = file_list.split(',') + num_samples = [] + for file in file_list: + with open(file, 'r') as f: + lines = list(f.readlines()) + num = len(lines) + if lines[-1] == '\n': + num -= 1 + num_samples.append(num) + return num_samples + + +def shift_tokens_by_multi_audios( + context_tokens, context_lengths, audio_feat_lens, context_start_idx, encoder_max_length +): + """ + split and shift the context tokens by the audio segments, then concatenate them back. This function assumes that the whole context + starts and ends with text tokens, and the audio segments are in between the text tokens. The audio segments are not allowed to be adjacent to each other. + Args: + context_tokens: tensor of shape [batch, max_context_len] + context_lengths: tensor of shape [batch,] + audio_feat_lens: List[List[int]] + context_start_idx: List[List[int]] + encoder_max_length: int + """ + new_context_tokens = [] + for i in range(context_tokens.shape[0]): + start_idx_list_i = context_start_idx[i] + [context_lengths[i]] + input_len_list = [start_idx_list_i[j + 1] - start_idx_list_i[j] for j in range(len(start_idx_list_i) - 1)] + context_tokens_list = context_tokens[i][: context_lengths[i]].split(input_len_list) + context_tokens_i = [context_tokens_list[0]] + for j in range(1, len(context_tokens_list)): + context_tokens_i.append( + torch.zeros(audio_feat_lens[i][j - 1], dtype=torch.long, device=context_tokens.device) + ) + context_tokens_i.append(context_tokens_list[j]) + context_tokens_i = torch.cat(context_tokens_i) + context_tokens_i = torch.nn.functional.pad( + context_tokens_i, (0, encoder_max_length - context_tokens_i.shape[0]) + ) + new_context_tokens.append(context_tokens_i) + new_context_tokens = torch.stack(new_context_tokens) + return new_context_tokens + + +def get_nested_dict_value(d, key, sep="."): + """ + Get the value of a nested dict given a key + Args: + d: dict + key: str + """ + for k in key.split(sep): + d = d[k] + return d + + +def align_feat_seq_list( + seq_list: List[torch.Tensor], + seq_len_list: List[torch.Tensor], + mode: str = "min", + pooling: str = 'mean', + target_len: Optional[int] = None, +): + """ + Align a list of feature sequences to the same length by repeating or discarding frames. + Args: + seq_list: List[torch.Tensor], list of tensors of shape [batch, hidden_size, seq_len] + seq_len_list: List[torch.Tensor], list of tensors of shape [batch,] + mode: str, "min" or "max" + pooling: str, "mean", "max", or "min" + Returns: + new_seq_list: List[torch.Tensor], list of tensors of shape [batch, hidden_size, new_seq_len] + new_seq_len_list: List[torch.Tensor], list of tensors of shape [batch,] + """ + MODES = ["min", "max"] + if mode not in MODES: + raise ValueError(f"mode {mode} not supported, available modes: {MODES}") + POOLING = ["mean", "max", "min", "avg"] + if pooling not in POOLING: + raise ValueError(f"pooling {pooling} not supported, available modes: {POOLING}") + + new_seq_len_list = [] + new_seq_list = [] + + if target_len is None: + target_len = [x.size(-1) for x in seq_list] + target_len = min(target_len) if mode == "min" else max(target_len) + + for seq, seq_len in zip(seq_list, seq_len_list): + curr_len = seq.size(-1) + if curr_len > target_len: + ratio = round(curr_len / target_len) + res = abs(ratio * target_len - curr_len) + if ratio * target_len > curr_len: # e.g., ratio = 1.9 + # repeat the last res frames + seq = torch.cat([seq, seq[:, :, -res:]], dim=-1) + seq_len += res * (seq_len > target_len).long() + elif ratio * target_len < curr_len: # e.g., ratio = 2.1 + # discard the last res frames + seq = seq[:, :, :-res] + seq_len -= res * (seq_len > target_len).long() + new_seq = seq.reshape(seq.size(0), seq.size(1), ratio, target_len) + if pooling == "min": + new_seq = new_seq.min(dim=2) + elif pooling == "max": + new_seq = new_seq.max(dim=2) + else: + new_seq = new_seq.mean(dim=2) + new_seq_len = torch.round(seq_len / ratio).long() + else: # curr_len <= target_len + ratio = round(target_len / curr_len) + res = abs(ratio * curr_len - target_len) + new_seq = torch.repeat_interleave(seq, ratio, dim=-1) + new_seq_len = seq_len * ratio + if ratio * curr_len > target_len: # e.g., ratio = 1.9 + new_seq = new_seq[:, :, :target_len] + new_seq_len = ( + seq_len * ratio - (ratio * seq_len - target_len) * (ratio * seq_len > target_len).long() + ) # subtract additional frames + elif ratio * curr_len < target_len: # e.g., ratio = 2.1 + new_seq = torch.cat([new_seq, seq[:, :, -res:]], dim=-1) + new_seq_list.append(new_seq) + new_seq_len_list.append(new_seq_len) + return new_seq_list, new_seq_len_list diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ea56429f4de1..536fc5bff7c8 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -174,7 +174,7 @@ def forward(self, **kwargs): the superclass by the square root of the hidden size specified in the configuration. """ embeddings = super().forward(**kwargs) - return embeddings * torch.tensor(self.config.hidden_size ** 0.5, dtype=embeddings.dtype) + return embeddings * torch.tensor(self.config.hidden_size**0.5, dtype=embeddings.dtype) class MegatronGPTExportableModel(torch.nn.Module, Exportable): @@ -196,11 +196,14 @@ def __init__(self, model): def forward(self, tokens, position_ids, attention_mask): if self.fp8_enabled and HAVE_TE: - with transformer_engine.pytorch.onnx_export(self.fp8_enabled), transformer_engine.pytorch.fp8_autocast( - enabled=self.fp8_enabled, fp8_recipe=self.fp8_recipe - ), torch.no_grad(), torch.inference_mode(), torch.autocast( - 'cuda', dtype=self.dtype - ), warnings.catch_warnings(): + with ( + transformer_engine.pytorch.onnx_export(self.fp8_enabled), + transformer_engine.pytorch.fp8_autocast(enabled=self.fp8_enabled, fp8_recipe=self.fp8_recipe), + torch.no_grad(), + torch.inference_mode(), + torch.autocast('cuda', dtype=self.dtype), + warnings.catch_warnings(), + ): warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*') assert tokens.shape == position_ids.shape assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1] @@ -211,9 +214,12 @@ def forward(self, tokens, position_ids, attention_mask): labels=None, ) else: - with torch.no_grad(), torch.inference_mode(), torch.autocast( - 'cuda', dtype=self.dtype - ), warnings.catch_warnings(): + with ( + torch.no_grad(), + torch.inference_mode(), + torch.autocast('cuda', dtype=self.dtype), + warnings.catch_warnings(), + ): warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*') assert tokens.shape == position_ids.shape assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1] @@ -509,7 +515,7 @@ def setup_optimizer_param_groups(self): self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model) def setup_mcore_distributed_parallel(self): - """Set up mcore distributed data parallel """ + """Set up mcore distributed data parallel""" if self.with_distributed_adam and self.use_mcore_dist_optim: config = get_model_config(self.model[0]) ddp_config = DistributedDataParallelConfig( @@ -641,7 +647,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): if self.validation_param_sync_overlap: param_sync_func = self.sync_overlap_parameters elif not self.use_mcore_dist_optim: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) grad_sync_func = self.reduce_overlap_gradients param_sync_func = self.sync_overlap_parameters else: @@ -744,9 +753,9 @@ def training_step_fwd_bwd_step_call(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - We pass the dataloader iterator function to the micro-batch scheduler. - The input batch to each micro-batch is fetched using the dataloader function - in the micro-batch fwd function. + We pass the dataloader iterator function to the micro-batch scheduler. + The input batch to each micro-batch is fetched using the dataloader function + in the micro-batch fwd function. """ # Initialize userbuffer communicators. if self.initialize_ub: @@ -877,7 +886,11 @@ def training_step(self, dataloader_iter): if self.log_memory_usage: mem_reserved = torch.cuda.max_memory_reserved() self.log( - 'peak_memory_usage', mem_reserved, prog_bar=True, rank_zero_only=True, batch_size=1, + 'peak_memory_usage', + mem_reserved, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) ## logging @@ -901,20 +914,29 @@ def training_step(self, dataloader_iter): lr = self._optimizer.param_groups[0]['lr'] self.log('lr', lr, rank_zero_only=True, batch_size=1) self.log( - 'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1, + 'global_step', + self.trainer.global_step, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) consumed_samples = self._compute_consumed_samples_after_training_step() # TODO: make sure compute_consumed_samples works for pipeline parallelism self.log( - 'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1, + 'consumed_samples', + consumed_samples, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) if self.rampup_batch_size: self.prev_global_batch_size = current_global_batch_size self.prev_consumed_samples = consumed_samples num_microbatch_calculator.update( - consumed_samples=consumed_samples, consistency_check=False, + consumed_samples=consumed_samples, + consistency_check=False, ) current_global_batch_size = num_microbatch_calculator.current_global_batch_size self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1) @@ -923,20 +945,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. """ return def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ return def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) or getattr( @@ -954,9 +976,9 @@ def _append_sequence_parallel_module_grads(self, module, grads): grads.append(grad.data) def allreduce_sequence_parallel_gradients(self): - """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. - Modified from megatron-lm: - https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 + """All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. + Modified from megatron-lm: + https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 """ grads = [] @@ -974,8 +996,7 @@ def allreduce_sequence_parallel_gradients(self): buf.copy_(synced) def allreduce_fsdp_sharding_omitted_gradients(self): - """ All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain). - """ + """All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain).""" assert isinstance(self.model, torch.nn.Module) grads = [] for param in self.model.parameters(): @@ -1022,16 +1043,16 @@ def allreduce_first_last_embeddings(self): torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) def _make_data_iterator_list(self, data_iterator: Iterator) -> List[Iterator]: - """ Convert data iterator into form expected by Megatron - - With interleaved pipeline parallelism, Megatron expects a - list of one data iterator per model chunk. Each model - chunk independently gets data from its data iterator, so - we need to interact with the data iterator multiple times - for each microbatch step. Instead of incorporating this - logic into the data loader, we cache the iterator's output - to the first model chunk and reuse it in the other model - chunks. + """Convert data iterator into form expected by Megatron + + With interleaved pipeline parallelism, Megatron expects a + list of one data iterator per model chunk. Each model + chunk independently gets data from its data iterator, so + we need to interact with the data iterator multiple times + for each microbatch step. Instead of incorporating this + logic into the data loader, we cache the iterator's output + to the first model chunk and reuse it in the other model + chunks. """ if not isinstance(self.model, list) or len(self.model) == 1: @@ -1159,7 +1180,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys.update(('labels', 'loss_mask')) if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys: required_keys.remove('attention_mask') - batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + batch = { + key: val.cuda(non_blocking=True) if key in required_keys and isinstance(val, torch.Tensor) else None + for key, val in batch.items() + } # slice batch along sequence dimension for context parallelism batch = self.get_batch_on_this_context_parallel_rank(batch) @@ -1323,10 +1347,10 @@ def id_func(output_tensor): def validation_step(self, dataloader_iter, dataloader_idx=0): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ mode = 'test' if self.trainer.testing else 'val' # Initialize userbuffer communicators. @@ -1387,7 +1411,9 @@ def on_validation_epoch_end(self): if self.loss_broadcast_src_rank is None: self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank() torch.distributed.broadcast( - averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(), + averaged_loss, + self.loss_broadcast_src_rank, + group=parallel_state.get_pipeline_model_parallel_group(), ) self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) @@ -1492,7 +1518,10 @@ def build_train_valid_test_datasets(self): dataset_type = MockGPTDataset if mock_dataset else GPTDataset self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder( - dataset_type, train_valid_test_num_samples, is_dataset_built_on_rank, dataset_config, + dataset_type, + train_valid_test_num_samples, + is_dataset_built_on_rank, + dataset_config, ).build() if self._train_ds is not None: @@ -1702,16 +1731,16 @@ def list_available_models(self): return None def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( @@ -1788,9 +1817,9 @@ def on_load_checkpoint(self, checkpoint) -> None: def on_validation_model_zero_grad(self) -> None: """ - Skip gradient zeroing at the beginning of validation routine. - This is needed when overlapping the AllGather of the updated parameters with the following valdation step. - """ + Skip gradient zeroing at the beginning of validation routine. + This is needed when overlapping the AllGather of the updated parameters with the following valdation step. + """ if not self.validation_param_sync_overlap: super().on_validation_model_zero_grad() @@ -1859,9 +1888,9 @@ def initialize_last_rank_embeddings(self): parallel_state.set_virtual_pipeline_model_parallel_rank(0) def _reset_activation_checkpointing_args(self): - """ Disables activation checkpointing completely and saves the values so that - _restore_activation_checkpointing_args can restore them later. This function must always be - called before _restore_activation_checkpointing_args. + """Disables activation checkpointing completely and saves the values so that + _restore_activation_checkpointing_args can restore them later. This function must always be + called before _restore_activation_checkpointing_args. """ # Store values to restore them later. self.last_activations_checkpoint_granularity = self.cfg.activations_checkpoint_granularity @@ -1888,9 +1917,9 @@ def _reset_activation_checkpointing_args(self): module.language_model.encoder.activations_checkpoint_layers_per_pipeline = None def _restore_activation_checkpointing_args(self): - """ Restores the activation checkpointing parameters using the values saved by - _reset_activation_checkpointing_args. This function must never be called before - _reset_activation_checkpointing_args. + """Restores the activation checkpointing parameters using the values saved by + _reset_activation_checkpointing_args. This function must never be called before + _reset_activation_checkpointing_args. """ # Restore config values. self.cfg.activations_checkpoint_granularity = self.last_activations_checkpoint_granularity @@ -1917,9 +1946,9 @@ def _restore_activation_checkpointing_args(self): ) def _reset_sequence_parallelism_args(self): - """ Disables sequence parallelism completely and saves the values so that - _restore_sequence_parallelism_args can restore them later. This function must always be - called before _restore_sequence_parallelism_args. + """Disables sequence parallelism completely and saves the values so that + _restore_sequence_parallelism_args can restore them later. This function must always be + called before _restore_sequence_parallelism_args. """ # Store values to restore them later. self.last_sequence_parallel = self.cfg.sequence_parallel @@ -1936,9 +1965,9 @@ def _reset_sequence_parallelism_args(self): mod.sequence_parallel = False def _restore_sequence_parallelism_args(self): - """ Restores the sequence parallelism parameters using the values saved by - _reset_sequence_parallelism_args. This function must never be called before - _reset_sequence_parallelism_args. + """Restores the sequence parallelism parameters using the values saved by + _reset_sequence_parallelism_args. This function must never be called before + _reset_sequence_parallelism_args. """ # Restore config values. self.cfg.sequence_parallel = self.last_sequence_parallel @@ -1952,10 +1981,10 @@ def _restore_sequence_parallelism_args(self): mod.sequence_parallel = self.last_sequence_parallel def build_transformer_config(self) -> TransformerConfig: - """ Builds the megatron core gpt transformer config for the model. - For attributes in the nemo model config that are the same - as the megatron core TransformerConfig, we will use the value from the nemo model config. - For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. + """Builds the megatron core gpt transformer config for the model. + For attributes in the nemo model config that are the same + as the megatron core TransformerConfig, we will use the value from the nemo model config. + For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. """ normalization = self.cfg.get('normalization', 'layernorm').lower() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index d7a5cf3f26bf..1b59b90d2968 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -354,7 +354,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): token_count_avg = sum(batch['token_count']) / len(batch['token_count']) # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() - batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} + batch = {k: v for k, v in batch.items() if isinstance(v, (torch.Tensor, list))} _, seq_length = batch['tokens'].shape data_iter = get_iterator_k_split(batch, get_num_microbatches()) @@ -367,7 +367,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): grad_sync_func = None param_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) grad_sync_func = self.reduce_overlap_gradients param_sync_func = self.sync_overlap_parameters @@ -855,13 +858,19 @@ def setup_training_dataloader(self): if hasattr(self, '_train_ds'): consumed_samples = self.compute_consumed_samples(0) self._train_dl = self.build_data_loader( - dataset=self._train_ds, data_cfg=self.cfg.data.train_ds, consumed_samples=consumed_samples, + dataset=self._train_ds, + data_cfg=self.cfg.data.train_ds, + consumed_samples=consumed_samples, ) def setup_eval_dataloader(self, datasets, data_cfg): dataloaders = [] for dataset in datasets: - eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0,) + eval_dl = self.build_data_loader( + dataset=dataset, + data_cfg=data_cfg, + consumed_samples=0, + ) dataloaders.append(eval_dl) return dataloaders diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 48234459453e..75c50146bfab 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -22,6 +22,8 @@ from torch import Tensor +from nemo.utils import logging, logging_mode + try: from apex.normalization import MixedFusedRMSNorm from apex.normalization.fused_layer_norm import FusedLayerNorm # NOQA @@ -310,9 +312,7 @@ def make_inference_attention_mask_3d(source_block, target_block, pad_id): def make_inference_history_mask_3d(block): batch, length = block.shape arange = torch.arange(length, device=block.device) - history_mask = (arange[None,] <= arange[:, None])[ - None, - ] + history_mask = (arange[None,] <= arange[:, None])[None,] history_mask = history_mask.expand(batch, length, length) return history_mask @@ -413,14 +413,56 @@ def get_all_params_for_weight_decay_optimization( return tuple(filter(lambda g: len(g['params']) > 0, param_groups)) -def get_iterator_k_split(batch: List[torch.Tensor], num_microbatches: int) -> Iterator: +def split_list(inputs, num_chunks): + """ + Split a list into equal sized chunks + """ + chunk_size = len(inputs) // num_chunks + assert len(inputs) % chunk_size == 0, "Issue with batch size configuration!" + return [inputs[i : i + chunk_size] for i in range(0, len(inputs), chunk_size)] + + +def get_iterator_k_split(batch: Union[Dict, List[torch.Tensor]], num_microbatches: int) -> Iterator: + """ + Split a batch into k microbatches, where the batch size is divisible by k. Batch could be + a dictionary of tensors or a list of tensors. A dictionary batch could also have items of List type, + as long as the length of that list is the same as the batch size. + """ if isinstance(batch, dict): - items = list(batch.items()) + discard_items = [k for k, v in batch.items() if not isinstance(v, (torch.Tensor, list))] + if len(discard_items) > 0: + logging.warning( + f"Only support splitting torch.Tensor and List[torch.Tensor]. Discarding the following keys from the batch: {discard_items}", + mode=logging_mode.ONCE, + ) + + batch = {k: v for k, v in batch.items() if isinstance(v, (torch.Tensor, list))} + tensor_items = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} + list_items = {k: v for k, v in batch.items() if isinstance(v, list)} + + # Split tensor items + items = list(tensor_items.items()) assert items[0][1].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" split_batch = [torch.tensor_split(item[1], num_microbatches, dim=0) for item in items] - microbatches = [[(items[i][0], split_batch[i][j]) for i in range(len(items))] for j in range(num_microbatches)] + + if len(list_items) == 0: + # Only have tensor items + microbatches = [ + [(items[i][0], split_batch[i][j]) for i in range(len(items))] for j in range(num_microbatches) + ] + else: + # Split list items + list_items = list(list_items.items()) + split_list_batch = [split_list(item[1], num_microbatches) for item in list_items] + # Merge tensor and list items + all_keys = [item[0] for item in items] + [item[0] for item in list_items] + all_split_batch = split_batch + split_list_batch + microbatches = [ + [(all_keys[i], all_split_batch[i][j]) for i in range(len(all_keys))] for j in range(num_microbatches) + ] microbatches = [dict(elem) for elem in microbatches] else: + # Split a list of torch tensors assert batch[0].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" split_batch = [ torch.tensor_split(item, num_microbatches, dim=0) if torch.is_tensor(item) else item for item in batch diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index cf39ed134768..97757b2e3826 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -219,7 +219,10 @@ def _validate_input_types(self, input_types=None, ignore_collections=False, **kw hasattr(value, 'neural_type') and is_semantic_typecheck_enabled() and not metadata.base_types[key].compare(value.neural_type) - in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,) + in ( + NeuralTypeComparisonResult.SAME, + NeuralTypeComparisonResult.GREATER, + ) ): error_msg = [ f"{input_types[key].compare(value.neural_type)} :", @@ -398,7 +401,10 @@ def __check_neural_type(self, obj, metadata: TypecheckMetadata, depth: int, name hasattr(obj, 'neural_type') and is_semantic_typecheck_enabled() and not type_val.compare(obj.neural_type) - in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,) + in ( + NeuralTypeComparisonResult.SAME, + NeuralTypeComparisonResult.GREATER, + ) ): raise TypeError( f"{type_val.compare(obj.neural_type)} : \n" @@ -711,6 +717,7 @@ def from_pretrained( return_config: bool = False, trainer: Optional['Trainer'] = None, save_restore_connector: SaveRestoreConnector = None, + return_model_file: Optional[bool] = False, ): """ Instantiates an instance of NeMo from NVIDIA NGC cloud @@ -726,6 +733,7 @@ def from_pretrained( strict: Passed to torch.load_state_dict. By default true. return_config: If set to true, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model. + return_model_file: If set to true, will return just the downloaded model file in cache Returns: A model instance of a particular model class or its underlying config (if return_config is set). @@ -751,6 +759,9 @@ def from_pretrained( model_name=model_name, refresh_cache=refresh_cache ) + if return_model_file: + return nemo_model_file_in_cache + instance = class_.restore_from( restore_path=nemo_model_file_in_cache, override_config_path=override_config_path, diff --git a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py index 690010ad29ca..f0c7847b8c9b 100644 --- a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +++ b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py @@ -124,7 +124,11 @@ ) parser.add_argument( - "--metadata_path", required=False, default=None, type=str, help="Path to metadata file for the dataset.", + "--metadata_path", + required=False, + default=None, + type=str, + help="Path to metadata file for the dataset.", ) parser.add_argument( @@ -165,7 +169,10 @@ ) parser.add_argument( - "--buckets_num", type=int, default=1, help="Number of buckets to create based on duration.", + "--buckets_num", + type=int, + default=1, + help="Number of buckets to create based on duration.", ) parser.add_argument( @@ -617,6 +624,15 @@ def _read_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): with open(manifest_path, 'r', encoding='utf-8') as m: for line in m: entry = json.loads(line) + audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file" + if audio_key not in entry: + raise KeyError(f"Manifest entry does not contain 'audio_filepath' or 'audio_file' key: {entry}") + audio_filepath = entry[audio_key] + if not os.path.isfile(audio_filepath) and not os.path.isabs(audio_filepath): + audio_filepath_abs = os.path.join(os.path.dirname(manifest_path), audio_filepath) + if not os.path.isfile(audio_filepath_abs): + raise FileNotFoundError(f"Could not find {audio_filepath} or {audio_filepath_abs}!") + entry[audio_key] = audio_filepath_abs if (config.max_duration is None or entry['duration'] < config.max_duration) and ( config.min_duration is None or entry['duration'] >= config.min_duration ): @@ -648,8 +664,7 @@ def _write_to_tar(self, tar, audio_filepath: str, squashed_filename: str) -> Non tar.addfile(ti, encoded_audio) def _create_shard(self, entries, target_dir, shard_id, manifest_folder): - """Creates a tarball containing the audio files from `entries`. - """ + """Creates a tarball containing the audio files from `entries`.""" if self.config.sort_in_shards: entries.sort(key=lambda x: x["duration"], reverse=False) diff --git a/tests/collections/multimodal/test_speechllm_models.py b/tests/collections/multimodal/test_speechllm_models.py new file mode 100644 index 000000000000..8698fed205ea --- /dev/null +++ b/tests/collections/multimodal/test_speechllm_models.py @@ -0,0 +1,266 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import pytorch_lightning as pl +import torch +from megatron.core import parallel_state +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.plugins.environments import TorchElasticEnvironment + +from nemo.collections.multimodal.speech_llm.models import modular_models +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import shift_tokens_by_multi_audios +from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy + + +class ModularAudioGPTModel(modular_models.ModularAudioGPTModel): + # disable logging to avoid MisconfigurationException + def log(self, *args, **kwargs): + pass + + +def setup_module(): + pl.seed_everything(1) + # init model parallel needed for LLM loss + init_method = 'tcp://' + master_ip = 'localhost' + master_port = '6000' + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group(backend='gloo', world_size=1, rank=0, init_method=init_method) + parallel_state.initialize_model_parallel(1, 1) + + +@pytest.fixture +def llm_model_config(): + this_test_dir = os.path.dirname(os.path.abspath(__file__)) + # Although most of the stuff in model is loaded from ckpt, we need configs + # for e.g. cfg.model.optim + config = OmegaConf.load( + os.path.join( + this_test_dir, + "../../../examples/multimodal/speech_llm/conf/modular_audio_gpt_config_peft.yaml", + ) + ) + # TODO(zhehuai): move the following to Test /home/TestData + config.model.restore_from_path = "/root/home/works/TestData/pretrained_models/megatron_gpt/gpt_pretrain_220m_len_4096_pos_alibi_step_595508_gbs256.nemo" + config.model.micro_batch_size = 2 + config.model.global_batch_size = 2 + config.model.data.validation_ds.manifest_filepath = ( + '/root/home/works/TestData/datasets/LibriSpeech/dev_clean_cleaned.json' + ) + config.model.data.train_ds.manifest_filepath = ( + '/root/home/works/TestData/datasets/LibriSpeech/dev_clean_cleaned.json' + ) + return config + + +@pytest.fixture +def trainer_config(): + config_trainer = DictConfig({}) + + if torch.cuda.is_available(): + accelerator = "gpu" + torch.set_default_device('cuda') + else: + accelerator = "cpu" + config_trainer.accelerator = accelerator + config_trainer.devices = 1 + config_trainer.num_nodes = 1 + config_trainer.max_epochs = 4 + config_trainer.max_steps = 1 + config_trainer.val_check_interval = 1.0 + + # for PyTorch Native AMP set precision=16 + config_trainer.precision = 32 + + # setup cluster environment parameters" + # use torch elastic cluster environment so `create_process_externally` is True + # the launcher is set to None. It will not try to spawn new processes. + # It won't create the misconfiguration error because of the `interactive session` + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + strategy = NLPDDPStrategy() + plugins = [TorchElasticEnvironment()] + trainer = pl.Trainer(logger=False, plugins=plugins, strategy=strategy, **config_trainer) + return trainer, config_trainer + + +@pytest.fixture +def perception_model_config(): + preprocessor = {"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor"} + encoder = { + "_target_": "nemo.collections.asr.modules.ConformerEncoder", + "feat_in": 64, + "n_layers": 8, + "d_model": 64, + "self_attention_model": "rel_pos_local_attn", + "att_context_size": [128, 128], + } + + model_config = DictConfig( + { + "_target_": "nemo.collections.multimodal.speechllm.modules.speechllm_perception.AudioPerceptionModule", + "preprocessor": DictConfig(preprocessor), + "encoder": DictConfig(encoder), + "modality_adapter": DictConfig(encoder), + "output_dim": 1024, + } + ) + return model_config + + +@pytest.fixture +def test_batch(): + signal_len = torch.from_numpy(np.array([64000, 64000])) + transcript = torch.arange(10).reshape(2, 5).int() + tokens = transcript[:, :-1] + labels = transcript[:, 1:] + transcript_length = torch.Tensor([3, 2]).int() + # assuming context_lengths = [1, 1] + loss_mask = torch.Tensor([[0, 1, 1, 0], [0, 1, 0, 0]]) + batch = { + 'audio_signal_length': signal_len, + 'tokens': tokens, + 'tokens_length': transcript_length, + 'contexts': torch.arange(260).reshape(2, 130).int(), + 'context_lengths': torch.Tensor([1, 1]).int(), + 'labels': labels, + 'answers': labels, + 'loss_mask': loss_mask, + } + batch['audio_signal'] = torch.randn([2, 64000]) + return batch + + +@pytest.mark.skip(reason="nedd to move pretrained GPT model to /home/works/TestData first") +class TestModularAudioGPTModel: + @pytest.mark.unit + def test_init_and_train(self, llm_model_config, perception_model_config, trainer_config): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" + llm_model_config.model.perception = perception_model_config + trainer, llm_model_config.trainer = trainer_config + model = ModularAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) + + assert isinstance(model.model, GPTModel) + with tempfile.TemporaryDirectory() as tmpdir: + save_path = str(Path(tmpdir) / "model.nemo") + model.train() + model.save_to(save_path) + + @pytest.mark.unit + def test_prepare_llm_input(self, llm_model_config, perception_model_config, trainer_config, test_batch): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" + llm_model_config.model.perception = perception_model_config + trainer, llm_model_config.trainer = trainer_config + model = ModularAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) + model.cuda() + model.train() + batch = {key: val.cuda(non_blocking=True) for key, val in test_batch.items()} + encoder_input, attention_mask, labels, loss_mask, encoder_length = model.prepare_llm_input(batch) + assert encoder_input.shape == (17, 2, 768) + assert np.allclose(encoder_input.sum().cpu().detach().numpy(), 15.783691) + assert attention_mask.shape == (2, 1, 17, 17) + assert labels.shape == (2, 17) + assert np.allclose(loss_mask.sum(axis=1).cpu().numpy(), [2, 1]) + assert np.allclose(encoder_length.cpu().numpy(), (16, 15)) + + @pytest.mark.unit + def test_training_step(self, llm_model_config, perception_model_config, trainer_config, test_batch): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" + llm_model_config.model.perception = perception_model_config + trainer, llm_model_config.trainer = trainer_config + model = ModularAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) + model.cuda() + model.on_train_start() + model.setup() + model.train() + loss_mean = model.training_step(iter([test_batch]), None) + assert np.allclose(loss_mean.cpu().detach().numpy(), 5.7052) + + @pytest.mark.unit + def test_validation_step(self, llm_model_config, perception_model_config, trainer_config, test_batch): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" + llm_model_config.model.perception = perception_model_config + trainer, llm_model_config.trainer = trainer_config + model = ModularAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) + model.cuda() + model.train() + batch = {key: val.cuda(non_blocking=True) for key, val in test_batch.items()} + loss_mean = model.validation_step(iter([batch]), 0) + assert np.allclose(loss_mean['loss'].cpu().detach().numpy(), 5.7052) + + @pytest.mark.unit + def test_predict_step(self, llm_model_config, perception_model_config, trainer_config, test_batch): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" + llm_model_config.model.perception = perception_model_config + trainer, llm_model_config.trainer = trainer_config + model = ModularAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) + model.cuda() + model.train() + batch = {key: val.cuda(non_blocking=True) for key, val in test_batch.items()} + response = model.predict_step(batch, 0, 0) + ground_truth = 'to suit you. Please note these are lecture notes from an alternate presentation. Copyright ⁇ ' + assert response['sentences'][0] == ground_truth + + @pytest.mark.unit + def test_concat_multi_features(self, llm_model_config, perception_model_config, trainer_config): + llm_model_config.model.pretrained_audio_model = "stt_en_fastconformer_transducer_large" + llm_model_config.model.perception = perception_model_config + trainer, llm_model_config.trainer = trainer_config + model = ModularAudioGPTModel.restore_from_pretrained_models(llm_model_config, trainer=trainer) + model.eval() + + feat_dim = 32 + encoded = [torch.ones([3, 16, feat_dim]), torch.ones([3, 16, feat_dim])] + encoded_len = [torch.LongTensor([12, 8, 4]), torch.LongTensor([12, 8, 4])] + input_embeds = torch.zeros([2, 32, feat_dim]) + input_length = torch.LongTensor([32, 28]) + context_start_idx = [[0, 4, 12, 20], [0, 8, 16, 25]] + encoder_input, encoder_length = model._concat_multi_features( + encoded, encoded_len, input_embeds, input_length, context_start_idx + ) + assert encoder_input.shape == (2, 56, feat_dim) # max audio_len + text_len = (12 + 8 + 4) + 32 = 56 + assert encoder_length.shape == (2,) + assert np.allclose(encoder_length.cpu().numpy(), (56, 52)) + assert encoder_input[0, : context_start_idx[0][1]].sum() == 0 # first 4 features are text features + assert np.allclose( + encoder_input[0, context_start_idx[0][1] : context_start_idx[0][1] + encoded_len[0][0]], + torch.ones([encoded_len[0][0], feat_dim]), + ) + + @pytest.mark.unit + def test_shift_tokens_by_multi_audios(self): + """This test is put here because its functionality is similar to _concat_multi_features()""" + encoder_max_length = 64 + audio_len = [torch.LongTensor([12, 8, 4]), torch.LongTensor([12, 8, 4])] + context_tokens = torch.ones([2, 32]) + context_length = torch.LongTensor([32, 28]) + context_start_idx = [[0, 4, 12, 20], [0, 8, 16, 25]] + new_context_tokens = shift_tokens_by_multi_audios( + context_tokens, context_length, audio_len, context_start_idx, encoder_max_length + ) + assert new_context_tokens.shape == (2, 64) + assert np.allclose(new_context_tokens[0, : context_start_idx[0][1]], torch.ones([context_start_idx[0][1]])) + assert np.allclose( + new_context_tokens[0, context_start_idx[0][1] : context_start_idx[0][1] + audio_len[0][0]], + torch.zeros([audio_len[0][0]]), + )