diff --git a/Jenkinsfile b/Jenkinsfile
index 6351d8d07302..bc9bba2765ea 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -134,7 +134,7 @@ pipeline {
       }
     }
 
-    stage('L2: Speech to Text WPE - CitriNet') {
+    stage('Speech to Text WPE - CitriNet') {
       steps {
         sh 'python examples/asr/asr_ctc/speech_to_text_ctc_bpe.py \
         --config-path="../conf/citrinet/" --config-name="config_bpe" \
@@ -150,7 +150,7 @@ pipeline {
       }
     }
 
-    stage('L2: Speech Pre-training - CitriNet') {
+    stage('Speech Pre-training - CitriNet') {
       steps {
         sh 'python examples/asr/speech_pretraining/speech_pre_training.py \
         --config-path="../conf/ssl/citrinet/" --config-name="citrinet_ssl_ci" \
@@ -164,6 +164,22 @@ pipeline {
       }
     }
 
+    stage('Speech To Text Finetuning') {
+      steps {
+        sh 'python examples/asr/speech_to_text_finetune.py \
+        --config-path="conf" --config-name="speech_to_text_finetune" \
+        model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
+        model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
+        init_from_nemo_model=/home/TestData/asr/stt_en_fastconformer_transducer_large.nemo \
+        model.tokenizer.update_tokenizer=False \
+        trainer.devices=[1] \
+        trainer.accelerator="gpu" \
+        +trainer.fast_dev_run=True \
+        exp_manager.exp_dir=examples/asr/speech_finetuning_results'
+        sh 'rm -rf examples/asr/speech_finetuning_results'
+      }
+    }
+
     // TODO: Please Fix Me
     // Error locating target 'nemo.collections.asr.modules.wav2vec_modules.ConvFeatureEncoder', see chained exception above.
     // stage('L2: Speech Pre-training - Wav2Vec') {
diff --git a/docs/source/asr/configs.rst b/docs/source/asr/configs.rst
index d21b40e34570..30c8a74c5176 100644
--- a/docs/source/asr/configs.rst
+++ b/docs/source/asr/configs.rst
@@ -984,8 +984,8 @@ Main parts of the config:
       batch_size: 16 # you may increase batch_size if your memory allows
       # other params
 
-Finetuning
-~~~~~~~~~~~
+Finetuning with Text-Only Data
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 To finetune existing ASR model using text-only data use ``/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py``
 script with the corresponding config ``/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml``.
@@ -1030,47 +1030,53 @@ Fine-tuning Configurations
 
 All ASR scripts support easy fine-tuning by partially/fully loading the pretrained weights from a checkpoint into the **currently instantiated model**. Note that the currently instantiated model should have parameters that match the pre-trained checkpoint (such that weights may load properly). In order to directly fine-tune a pre-existing checkpoint, please follow the tutorial `ASR Language Fine-tuning. `_
 
-Pre-trained weights can be provided in multiple ways -
+Models can be fine-tuned in two ways:
+
+* By updating or retaining the current tokenizer alone
+* By updating the model architecture and tokenizer
+
+Fine-tuning by updating or retaining the current tokenizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In this case, the model architecture is not updated. The model can be initialized with pre-trained weights in
+one of two ways:
 
 1) Providing a path to a NeMo model (via ``init_from_nemo_model``)
 2) Providing a name of a pretrained NeMo model (which will be downloaded via the cloud) (via ``init_from_pretrained_model``)
-3) Providing a path to a Pytorch Lightning checkpoint file (via ``init_from_ptl_ckpt``)
 
-There are multiple ASR subtasks inside the ``examples/asr/`` directory, you can substitute the ``<subtask>`` tag below.
+Then users can either keep the existing tokenizer or update the tokenizer with new vocabulary. This is useful when users don't want to update the model architecture
+but want to update the tokenizer with new vocabulary.
 
-Fine-tuning via a NeMo model
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+The same script can be used to finetune CTC, RNNT, or Hybrid models as well.
+
+The ``examples/asr/speech_to_text_finetune.py`` script supports this type of fine-tuning with the following arguments:
 
 .. code-block:: sh
 
-    python examples/asr//script_to_.py \
+    python examples/asr/speech_to_text_finetune.py \
     --config-path=<path to dir of configs> \
     --config-name=<name of config without .yaml>) \
     model.train_ds.manifest_filepath="<path to manifest file>" \
     model.validation_ds.manifest_filepath="<path to manifest file>" \
+    model.tokenizer.update_tokenizer=<True/False> \ # True to update tokenizer, False to retain existing tokenizer
+    model.tokenizer.dir=<path to tokenizer dir> \ # Path to tokenizer dir when update_tokenizer=True
+    model.tokenizer.type=<tokenizer type> \ # tokenizer type when update_tokenizer=True
     trainer.devices=-1 \
     trainer.accelerator='gpu' \
     trainer.max_epochs=50 \
-    +init_from_nemo_model="<path to .nemo model file>"
+    +init_from_nemo_model="<path to .nemo model file>" (or +init_from_pretrained_model="<name of pretrained checkpoint>")
 
+Fine-tuning by changing model architecture and tokenizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
-Fine-tuning via a NeMo pretrained model name
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If users want to update the model architecture as well, they can use the following scripts:
 
-.. code-block:: sh
+In this case, pre-trained weights can be provided in multiple ways:
 
-    python examples/asr//script_to_.py \
-    --config-path=<path to dir of configs> \
-    --config-name=<name of config without .yaml>) \
-    model.train_ds.manifest_filepath="<path to manifest file>" \
-    model.validation_ds.manifest_filepath="<path to manifest file>" \
-    trainer.devices=-1 \
-    trainer.accelerator='gpu' \
-    trainer.max_epochs=50 \
-    +init_from_pretrained_model="<name of pretrained checkpoint>"
+1) Providing a path to a NeMo model (via ``init_from_nemo_model``)
+2) Providing a name of a pretrained NeMo model (which will be downloaded via the cloud) (via ``init_from_pretrained_model``)
+3) Providing a path to a Pytorch Lightning checkpoint file (via ``init_from_ptl_ckpt``)
 
-Fine-tuning via a Pytorch Lightning checkpoint
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+There are multiple ASR subtasks inside the ``examples/asr/`` directory; you can substitute the ``<subtask>`` tag below.
 
 .. code-block:: sh
 
@@ -1082,7 +1088,16 @@ Fine-tuning via a Pytorch Lightning checkpoint
     trainer.devices=-1 \
     trainer.accelerator='gpu' \
     trainer.max_epochs=50 \
-    +init_from_ptl_ckpt=""
+    +init_from_nemo_model="<path to .nemo model file>" # (or +init_from_pretrained_model, +init_from_ptl_ckpt)
+
+To reinitialize part of the model (making it different from the pretrained model), users can specify the modules to include or exclude through the config:
+
+.. code-block:: yaml
+
+    init_from_nemo_model: "<path to .nemo model file>"
+    asr_model:
+        include: ["preprocessor","encoder"]
+        exclude: ["decoder"]
 
 Fine-tuning Execution Flow Diagram
 ----------------------------------
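A note on the ``include``/``exclude`` selection above: NeMo resolves it internally when the config is consumed. As a rough sketch of what that selection amounts to (the paths and the simple prefix-matching rule below are illustrative assumptions, not the library's exact logic):

```python
# Illustrative only: approximate effect of include/exclude on weight loading.
# Paths are placeholders; NeMo performs this selection itself when
# `init_from_nemo_model` with include/exclude is present in the config.
from nemo.collections.asr.models import ASRModel

target = ASRModel.restore_from("/models/my_new_architecture.nemo")    # model being fine-tuned
source = ASRModel.restore_from("/models/pretrained_checkpoint.nemo")  # pretrained weights

include, exclude = ["preprocessor", "encoder"], ["decoder"]
selected = {
    name: tensor
    for name, tensor in source.state_dict().items()
    if any(name.startswith(key) for key in include)
    and not any(name.startswith(key) for key in exclude)
}
# strict=False leaves every parameter that was not selected (e.g. the decoder)
# at its freshly initialized value in the target model.
target.load_state_dict(selected, strict=False)
```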
diff --git a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
index 9e3da8d3545f..10fa3750948d 100644
--- a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
+++ b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
@@ -254,7 +254,6 @@ trainer:
   precision: 32 # 16, 32, or bf16
   log_every_n_steps: 10  # Interval of logging.
   enable_progress_bar: True
-  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
   num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
   check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
   sync_batchnorm: true
diff --git a/examples/asr/conf/speech_to_text_finetune.yaml b/examples/asr/conf/speech_to_text_finetune.yaml
new file mode 100644
index 000000000000..415172b33bb9
--- /dev/null
+++ b/examples/asr/conf/speech_to_text_finetune.yaml
@@ -0,0 +1,118 @@
+name: "Speech_To_Text_Finetuning"
+
+# use `init_from_nemo_model` or `init_from_pretrained_model` to initialize the model
+# We do not currently support `init_from_ptl_ckpt` to create a single script for all types of models.
+init_from_nemo_model: null # path to nemo model
+
+model:
+  sample_rate: 16000
+  compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
+  log_prediction: true # enables logging sample predictions in the output during training
+  rnnt_reduction: 'mean_volume'
+  skip_nan_grad: false
+
+  train_ds:
+    manifest_filepath: ???
+    sample_rate: ${model.sample_rate}
+    batch_size: 16 # you may increase batch_size if your memory allows
+    shuffle: true
+    num_workers: 8
+    pin_memory: true
+    max_duration: 20
+    min_duration: 0.1
+    # tarred datasets
+    is_tarred: false
+    tarred_audio_filepaths: null
+    shuffle_n: 2048
+    # bucketing params
+    bucketing_strategy: "fully_randomized"
+    bucketing_batch_size: null
+
+  validation_ds:
+    manifest_filepath: ???
+    sample_rate: ${model.sample_rate}
+    batch_size: 16
+    shuffle: false
+    use_start_end_token: false
+    num_workers: 8
+    pin_memory: true
+
+  test_ds:
+    manifest_filepath: null
+    sample_rate: ${model.sample_rate}
+    batch_size: 16
+    shuffle: false
+    use_start_end_token: false
+    num_workers: 8
+    pin_memory: true
+
+  char_labels: # use for char based models
+    update_labels: false
+    labels: null # example list config: \[' ', 'a', 'b', 'c'\]
+
+  tokenizer: # use for spe/bpe based tokenizer models
+    update_tokenizer: false
+    dir: null # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
+    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
+
+  spec_augment:
+    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
+    freq_masks: 2 # set to zero to disable it
+    time_masks: 10 # set to zero to disable it
+    freq_width: 27
+    time_width: 0.05
+
+  optim:
+    name: adamw
+    lr: 1e-4
+    # optimizer arguments
+    betas: [0.9, 0.98]
+    weight_decay: 1e-3
+
+    # scheduler setup
+    sched:
+      name: CosineAnnealing
+      # scheduler config override
+      warmup_steps: 5000
+      warmup_ratio: null
+      min_lr: 5e-6
+
+trainer:
+  devices: -1 # number of GPUs, -1 would use all available GPUs
+  num_nodes: 1
+  max_epochs: 50
+  max_steps: -1 # computed at runtime if not set
+  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
+  accelerator: auto
+  strategy: ddp
+  accumulate_grad_batches: 1
+  gradient_clip_val: 0.0
+  precision: 32 # 16, 32, or bf16
+  log_every_n_steps: 10 # Interval of logging.
+  enable_progress_bar: True
+  num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
+  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
+  sync_batchnorm: true
+  enable_checkpointing: False # Provided by exp_manager
+  logger: false # Provided by exp_manager
+  benchmark: false # needs to be false for models with variable-length speech input as it slows down training
+
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+  create_tensorboard_logger: true
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    # in case of multiple validation sets, first one is used
+    monitor: "val_wer"
+    mode: "min"
+    save_top_k: 5
+    always_save_nemo: True # saves the checkpoints as nemo files along with PTL checkpoints
+  resume_if_exists: false
+  resume_ignore_no_checkpoint: false
+
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: null
+    project: null
diff --git a/examples/asr/speech_to_text_finetune.py b/examples/asr/speech_to_text_finetune.py
new file mode 100644
index 000000000000..a5ba95b41221
--- /dev/null
+++ b/examples/asr/speech_to_text_finetune.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script can be used to fine-tune a speech-to-text model of any instance type when users want to
+fine-tune an existing model without changing its core architecture but may change the tokenizer.
+One can mention the pretrained model in two ways:
+1) `init_from_nemo_model` or
+2) `init_from_pretrained_model` in the configuration.
+
+To update the model architecture in conjunction with other modifications, it is advisable to use the
+primary 'speech_to_text_rnnt/ctc_*.py' scripts.
+
+Note: To create a single script for all model types, we currently only support two types of
+initializations:
+1) `init_from_nemo_model`, and
+2) `init_from_pretrained_model`,
+but not `init_from_ptl_ckpt`.
+
+To train with the tokenizer of the base model, keep `model.tokenizer.update_tokenizer` as false;
+otherwise set it to true and provide the tokenizer dir along with the tokenizer type.
+
+To fine-tune the model, use the following commands:
+
+For initialization from a NeMo model:
+```sh
+python examples/asr/speech_to_text_finetune.py \
+    init_from_nemo_model=<path to .nemo model file>
+```
+
+For initialization from a pretrained model:
+```sh
+python examples/asr/speech_to_text_finetune.py \
+    init_from_pretrained_model=<name of pretrained checkpoint>
+```
+
+# Fine-Tune a Model
+
+For documentation on fine-tuning this model, please visit:
+https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
+"""
+
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+from pytorch_lightning.utilities import rank_zero_only
+
+from nemo.collections.asr.models import ASRModel
+from nemo.core.config import hydra_runner
+from nemo.utils import logging, model_utils
+from nemo.utils.exp_manager import exp_manager
+
+
+@rank_zero_only
+def get_base_model(cfg):
+    """
+    Returns the base model to be fine-tuned.
+    Currently supports two types of initializations:
+    1) `init_from_nemo_model`, and
+    2) `init_from_pretrained_model`.
+    Args:
+        cfg: config
+    Returns:
+        asr_model: ASRModel instance
+    """
+    asr_model = None
+    nemo_model_path = cfg.get('init_from_nemo_model', None)
+    pretrained_name = cfg.get('init_from_pretrained_model', None)
+    if nemo_model_path is not None and pretrained_name is not None:
+        raise ValueError("Only pass `init_from_nemo_model` or `init_from_pretrained_model` but not both")
+    elif nemo_model_path is None and pretrained_name is None:
+        raise ValueError(
+            "Both `init_from_nemo_model` and `init_from_pretrained_model` cannot be None; please pass at least one of them"
+        )
+    elif nemo_model_path is not None:
+        asr_model = ASRModel.restore_from(restore_path=nemo_model_path)
+    elif pretrained_name is not None:
+        asr_model = ASRModel.from_pretrained(model_name=pretrained_name)
+
+    return asr_model
+
+
+def check_vocabulary(asr_model, cfg):
+    """
+    Checks if the decoder and vocabulary of the model need to be updated.
+    If either of them needs to be updated, it updates them and returns the updated model;
+    otherwise, the vocabulary is reused from the pre-trained model.
+    Args:
+        asr_model: ASRModel instance
+        cfg: config
+    Returns:
+        asr_model: ASRModel instance with updated decoder and vocabulary
+    """
+    if hasattr(cfg.model.tokenizer, 'update_tokenizer') and cfg.model.tokenizer.update_tokenizer:
+        if hasattr(cfg.model.char_labels, 'update_labels') and cfg.model.char_labels.update_labels:
+            raise ValueError(
+                "Both `model.tokenizer.update_tokenizer` and `model.char_labels.update_labels` cannot be passed together"
+            )
+        else:
+            asr_model = update_tokenizer(asr_model, cfg.model.tokenizer.dir, cfg.model.tokenizer.type)
+    elif hasattr(cfg.model, 'char_labels') and cfg.model.char_labels.update_labels:
+        asr_model.change_vocabulary(new_vocabulary=cfg.model.char_labels.labels)
+        logging.warning("The vocabulary of the model has been updated with provided char labels.")
+    else:
+        logging.info("Reusing the vocabulary from the pre-trained model.")
+
+    return asr_model
+
+
+def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
+    """
+    Updates the tokenizer of the model and also reinitializes the decoder if the vocabulary size
+    of the new tokenizer differs from that of the loaded model.
+    Args:
+        asr_model: ASRModel instance
+        tokenizer_dir: tokenizer directory
+        tokenizer_type: tokenizer type
+    Returns:
+        asr_model: ASRModel instance with updated tokenizer and decoder
+    """
+    vocab_size = asr_model.tokenizer.vocab_size
+    decoder = asr_model.decoder.state_dict()
+    if hasattr(asr_model, 'joint'):
+        joint_state = asr_model.joint.state_dict()
+    else:
+        joint_state = None
+
+    if tokenizer_dir is None:
+        raise ValueError("dir must be specified if update_tokenizer is True")
+    logging.info("Using the tokenizer provided through config")
+    asr_model.change_vocabulary(new_tokenizer_dir=tokenizer_dir, new_tokenizer_type=tokenizer_type)
+    if asr_model.tokenizer.vocab_size != vocab_size:
+        logging.warning(
+            "The vocabulary size of the new tokenizer differs from that of the loaded model. As a result, finetuning will proceed with the new vocabulary, and the decoder will be reinitialized."
+        )
+    else:
+        asr_model.decoder.load_state_dict(decoder)
+        if joint_state is not None:
+            asr_model.joint.load_state_dict(joint_state)
+
+    return asr_model
+
+
+def setup_dataloaders(asr_model, cfg):
+    """
+    Sets up the training, validation and test dataloaders for the model.
+    Args:
+        asr_model: ASRModel instance
+        cfg: config
+    Returns:
+        asr_model: ASRModel instance with updated dataloaders
+    """
+    cfg = model_utils.convert_model_config_to_dict_config(cfg)
+    asr_model.setup_training_data(cfg.model.train_ds)
+    asr_model.setup_multiple_validation_data(cfg.model.validation_ds)
+    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
+        asr_model.setup_multiple_test_data(cfg.model.test_ds)
+
+    return asr_model
+
+
+@hydra_runner(config_path="conf", config_name="speech_to_text_finetune")
+def main(cfg):
+    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
+
+    trainer = pl.Trainer(**cfg.trainer)
+    exp_manager(trainer, cfg.get("exp_manager", None))
+
+    if hasattr(cfg, 'init_from_ptl_ckpt') and cfg.init_from_ptl_ckpt is not None:
+        raise NotImplementedError(
+            "Currently, for simplicity of a single script for all model types, we only support `init_from_nemo_model` and `init_from_pretrained_model`"
+        )
+
+    asr_model = get_base_model(cfg)
+
+    # Check vocabulary type and update if needed
+    asr_model = check_vocabulary(asr_model, cfg)
+
+    # Setup Data
+    asr_model = setup_dataloaders(asr_model, cfg)
+
+    # Setup Optimizer
+    asr_model.setup_optimization(cfg.model.optim)
+
+    # Setup SpecAug
+    if hasattr(cfg.model, 'spec_augment') and cfg.model.spec_augment is not None:
+        asr_model.spec_augment = ASRModel.from_config_dict(cfg.model.spec_augment)
+
+    trainer.fit(asr_model)
+
+
+if __name__ == '__main__':
+    main()  # noqa pylint: disable=no-value-for-parameter
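For reference, a typical invocation of the new script, assembled from the config and docs added in this diff, might look like the sketch below; every path is a placeholder, and the tokenizer overrides are only needed because ``update_tokenizer=True`` is set:

```sh
# All paths below are placeholders.
python examples/asr/speech_to_text_finetune.py \
    --config-path=conf --config-name=speech_to_text_finetune \
    model.train_ds.manifest_filepath=/data/train_manifest.json \
    model.validation_ds.manifest_filepath=/data/val_manifest.json \
    init_from_nemo_model=/models/stt_en_fastconformer_transducer_large.nemo \
    model.tokenizer.update_tokenizer=True \
    model.tokenizer.dir=/data/tokenizers/my_new_tokenizer \
    model.tokenizer.type=bpe \
    trainer.devices=-1 \
    trainer.accelerator=gpu \
    trainer.max_epochs=50 \
    exp_manager.exp_dir=./speech_finetuning_results
```

With ``update_tokenizer=True``, ``update_tokenizer()`` in the script calls ``change_vocabulary`` and keeps the original decoder (and joint, for transducer models) weights only when the new tokenizer's vocabulary size matches the old one; otherwise the decoder is reinitialized and fine-tuning proceeds with the new vocabulary.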