Commit

improve code to support all decoder types
Signed-off-by: Nithin Rao Koluguri <nithinraok>
Nithin Rao Koluguri committed Aug 28, 2023
1 parent dbd64dd commit bf35b0b
Showing 3 changed files with 101 additions and 74 deletions.
63 changes: 39 additions & 24 deletions docs/source/asr/configs.rst
@@ -984,8 +984,8 @@ Main parts of the config:
batch_size: 16 # you may increase batch_size if your memory allows
# other params
Finetuning
~~~~~~~~~~~
Finetuning with Text-Only Data
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To finetune existing ASR model using text-only data use ``<NeMo_git_root>/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py`` script with the corresponding config ``<NeMo_git_root>/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml``.

@@ -1030,47 +1030,53 @@ Fine-tuning Configurations
All ASR scripts support easy fine-tuning by partially/fully loading the pretrained weights from a checkpoint into the **currently instantiated model**. Note that the currently instantiated model should have parameters that match the pre-trained checkpoint (so that the weights load properly). To directly fine-tune a pre-existing checkpoint, please follow the tutorial `ASR Language Fine-tuning <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb>`_.
Pre-trained weights can be provided in multiple ways -
Models can be fine-tuned in two ways:
* By updating or retaining current tokenizer alone
* By updating model architecture and tokenizer
Fine-tuning by updating or retaining current tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In this case, the model architecture is not updated. The model is initialized with the pre-trained weights in one of the following ways:
1) Providing a path to a NeMo model (via ``init_from_nemo_model``)
2) Providing a name of a pretrained NeMo model (which will be downloaded from the cloud) (via ``init_from_pretrained_model``)
3) Providing a path to a Pytorch Lightning checkpoint file (via ``init_from_ptl_ckpt``)
There are multiple ASR subtasks inside the ``examples/asr/`` directory; you can substitute the ``<subtask>`` tag below.
Users can then either reuse the existing tokenizer or update it with a new vocabulary. The latter is useful when the model architecture should stay the same but the tokenizer needs a new vocabulary.
Fine-tuning via a NeMo model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The same script can be used to finetune CTC, RNNT, or hybrid models.
The ``<NeMo_repo>/examples/asr/speech_to_text_finetune.py`` script supports this type of fine-tuning with the following arguments:
.. code-block:: sh
python examples/asr/<subtask>/script_to_<script_name>.py \
python examples/asr/speech_to_text_finetune.py \
--config-path=<path to dir of configs> \
    --config-name=<name of config without .yaml> \
model.train_ds.manifest_filepath="<path to manifest file>" \
model.validation_ds.manifest_filepath="<path to manifest file>" \
model.tokenizer.update_tokenizer=<True/False> \ # True to update tokenizer, False to retain existing tokenizer
model.tokenizer.dir=<path to tokenizer dir> \ # Path to tokenizer dir when update_tokenizer=True
model.tokenizer.type=<tokenizer type> \ # tokenizer type when update_tokenizer=True
trainer.devices=-1 \
trainer.accelerator='gpu' \
trainer.max_epochs=50 \
+init_from_nemo_model="<path to .nemo model file>"
+init_from_nemo_model="<path to .nemo model file>" (or +init_from_pretrained_model="<name of pretrained checkpoint>")
Fine-tuning by changing model architecture and tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fine-tuning via a NeMo pretrained model name
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If users want to update the model architecture as well, they can use the following script:
.. code-block:: sh
To provide the pretrained model, users can pass pre-trained weights in multiple ways:
python examples/asr/<subtask>/script_to_<script_name>.py \
--config-path=<path to dir of configs> \
    --config-name=<name of config without .yaml> \
model.train_ds.manifest_filepath="<path to manifest file>" \
model.validation_ds.manifest_filepath="<path to manifest file>" \
trainer.devices=-1 \
trainer.accelerator='gpu' \
trainer.max_epochs=50 \
+init_from_pretrained_model="<name of pretrained checkpoint>"
1) Providing a path to a NeMo model (via ``init_from_nemo_model``)
2) Providing a name of a pretrained NeMo model (which will be downloaded from the cloud) (via ``init_from_pretrained_model``)
3) Providing a path to a Pytorch Lightning checkpoint file (via ``init_from_ptl_ckpt``)
Fine-tuning via a Pytorch Lightning checkpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
There are multiple ASR subtasks inside the ``examples/asr/`` directory; you can substitute the ``<subtask>`` tag below.
.. code-block:: sh
@@ -1082,7 +1088,16 @@ Fine-tuning via a Pytorch Lightning checkpoint
trainer.devices=-1 \
trainer.accelerator='gpu' \
trainer.max_epochs=50 \
+init_from_ptl_ckpt="<name of pytorch lightning checkpoint>"
+init_from_nemo_model="<path to .nemo model file>" # (or +init_from_pretrained_model, +init_from_ptl_ckpt )
To reinitialize part of the model so that it differs from the pretrained model, users can specify the modules to include or exclude through the config:
.. code-block:: yaml
init_from_nemo_model: "<path to .nemo model file>"
asr_model:
include: ["preprocessor","encoder"]
exclude: ["decoder"]
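
As a rough illustration of what such ``include``/``exclude`` filtering amounts to, the sketch below copies only the matching parameters from a pretrained model into a freshly initialized one. ``filter_state_dict`` and ``ToyASRModel`` are hypothetical stand-ins for illustration only, not NeMo APIs; the actual selective loading is handled internally by the ``init_from_*`` logic.

.. code-block:: python

    import torch.nn as nn

    def filter_state_dict(state_dict, include, exclude):
        # Keep parameters whose names contain any `include` substring and none of the `exclude` substrings.
        return {
            name: tensor
            for name, tensor in state_dict.items()
            if any(tag in name for tag in include) and not any(tag in name for tag in exclude)
        }

    class ToyASRModel(nn.Module):
        # Minimal stand-in with the usual submodule names of an ASR model.
        def __init__(self):
            super().__init__()
            self.preprocessor = nn.Linear(8, 8)
            self.encoder = nn.Linear(8, 8)
            self.decoder = nn.Linear(8, 4)

    pretrained, current = ToyASRModel(), ToyASRModel()
    partial = filter_state_dict(
        pretrained.state_dict(), include=["preprocessor", "encoder"], exclude=["decoder"]
    )
    # strict=False leaves the excluded decoder parameters at their fresh initialization.
    current.load_state_dict(partial, strict=False)
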
Fine-tuning Execution Flow Diagram
----------------------------------
15 changes: 7 additions & 8 deletions examples/asr/conf/speech_to_text_finetune.yaml
@@ -2,8 +2,7 @@ name: "Speech_To_Text_Finetuning"

# use `init_from_nemo_model` or `init_from_pretrained_model` to initialize the model
# `init_from_ptl_ckpt` is not currently supported, so that a single script can cover all model types.
init_from_nemo_model: null
init_from_pretrained_model: null
init_from_nemo_model: null # path to nemo model

model:
sample_rate: 16000
@@ -19,7 +18,7 @@ model:
shuffle: true
num_workers: 8
pin_memory: true
max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
max_duration: 20
min_duration: 0.1
# tarred datasets
is_tarred: false
@@ -61,7 +60,7 @@ model:

optim:
name: adamw
lr: 5e-3
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3
@@ -70,14 +69,14 @@ model:
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 15000
warmup_steps: 5000
warmup_ratio: null
min_lr: 5e-4
min_lr: 5e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 500
max_epochs: 50
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
@@ -105,7 +104,7 @@ exp_manager:
monitor: "val_wer"
mode: "min"
save_top_k: 5
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints
always_save_nemo: True # saves the checkpoints as nemo files along with PTL checkpoints
resume_if_exists: false
resume_ignore_no_checkpoint: false

97 changes: 55 additions & 42 deletions examples/asr/speech_to_text_finetune.py
@@ -50,8 +50,6 @@
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
"""

import copy

import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning.utilities import rank_zero_only
@@ -62,6 +60,59 @@
from nemo.utils.exp_manager import exp_manager


@rank_zero_only
def get_base_model(cfg):
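    # Loads the base model either from a local .nemo file (`init_from_nemo_model`)
    # or by name from the cloud (`init_from_pretrained_model`); exactly one must be set.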
asr_model = None
nemo_model_path = cfg.get('init_from_nemo_model', None)
pretrained_name = cfg.get('init_from_pretrained_model', None)
if nemo_model_path is not None and pretrained_name is not None:
raise ValueError("Only pass `init_from_nemo_model` or `init_from_pretrained_model` but not both")
elif nemo_model_path is None and pretrained_name is None:
raise ValueError(
"Both `init_from_nemo_model` and `init_from_pretrained_model cannot be None, should pass atleast one of them"
)
elif nemo_model_path is not None:
asr_model = ASRModel.restore_from(restore_path=nemo_model_path)
elif pretrained_name is not None:
asr_model = ASRModel.from_pretrained(model_name=pretrained_name)

return asr_model


def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
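    # Swaps in the tokenizer from `tokenizer_dir`. If the new vocabulary size matches the old one,
    # the original decoder (and joint, for RNNT/hybrid models) weights are restored afterwards;
    # otherwise the decoder is left reinitialized for the new vocabulary.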
vocab_size = asr_model.tokenizer.vocab_size
decoder = asr_model.decoder.state_dict()
if hasattr(asr_model, 'joint'):
joint_state = asr_model.joint.state_dict()
else:
joint_state = None

if tokenizer_dir is None:
raise ValueError("dir must be specified if update_tokenizer is True")
logging.info("Using the tokenizer provided through config")
asr_model.change_vocabulary(new_tokenizer_dir=tokenizer_dir, new_tokenizer_type=tokenizer_type)
if asr_model.tokenizer.vocab_size != vocab_size:
logging.warning(
"The vocabulary size of the new tokenizer differs from that of the loaded model. As a result, finetuning will proceed with the new vocabulary, and the decoder will be reinitialized."
)
else:
asr_model.decoder.load_state_dict(decoder)
if joint_state is not None:
asr_model.joint.load_state_dict(joint_state)

return asr_model


def setup_dataloaders(asr_model, cfg):
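    # Attaches train/validation (and optional test) dataloaders from the finetuning config.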
cfg = model_utils.convert_model_config_to_dict_config(cfg)
asr_model.setup_training_data(cfg.model.train_ds)
asr_model.setup_validation_data(cfg.model.validation_ds)
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
asr_model.setup_test_data(cfg.model.test_ds)

return asr_model


@hydra_runner(config_path="conf", config_name="speech_to_text_finetune")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
@@ -74,54 +125,16 @@ def main(cfg):
"Currently for simplicity of single script for all model types, we only support `init_from_nemo_model` and `init_from_pretrained_model`"
)

@rank_zero_only
def get_base_model(cfg):
asr_model = None
nemo_model_path = cfg.get('init_from_nemo_model', None)
pretrained_name = cfg.get('init_from_pretrained_model', None)
if nemo_model_path is not None and pretrained_name is not None:
raise ValueError("Only pass `init_from_nemo_model` or `init_from_pretrained_model` but not both")
elif nemo_model_path is None and pretrained_name is None:
raise ValueError(
"Both `init_from_nemo_model` and `init_from_pretrained_model cannot be None, should pass atleast one of them"
)
elif nemo_model_path is not None:
asr_model = ASRModel.restore_from(restore_path=nemo_model_path)
elif pretrained_name is not None:
asr_model = ASRModel.from_pretrained(model_name=pretrained_name)

return asr_model

asr_model = get_base_model(cfg)
vocab_size = asr_model.tokenizer.vocab_size

# if new tokenizer is provided, use it
if hasattr(cfg.model.tokenizer, 'update_tokenizer') and cfg.model.tokenizer.update_tokenizer:
decoder = copy.deepcopy(asr_model.decoder)
joint_state = copy.deepcopy(asr_model.joint)

if cfg.model.tokenizer.dir is None:
raise ValueError("dir must be specified if update_tokenizer is True")
logging.info("Using the tokenizer provided through config")
asr_model.change_vocabulary(
new_tokenizer_dir=cfg.model.tokenizer.dir, new_tokenizer_type=cfg.model.tokenizer.type
)
if asr_model.tokenizer.vocab_size != vocab_size:
logging.warning(
"The vocabulary size of the new tokenizer differs from that of the loaded model. As a result, finetuning will proceed with the new vocabulary, and the decoder will be reinitialized."
)
else:
asr_model.decoder = decoder
asr_model.joint = joint_state
asr_model = update_tokenizer(asr_model, cfg.model.tokenizer.dir, cfg.model.tokenizer.type)
else:
logging.info("Reusing the tokenizer from the loaded model.")

# Setup Data
cfg = model_utils.convert_model_config_to_dict_config(cfg)
asr_model.setup_training_data(cfg.model.train_ds)
asr_model.setup_validation_data(cfg.model.validation_ds)
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
asr_model.setup_test_data(cfg.model.test_ds)
asr_model = setup_dataloaders(asr_model, cfg)

# Setup Optimizer
asr_model.setup_optimization(cfg.model.optim)
