Add support for finetuning with huggingface datasets #7834

Merged: 18 commits, Nov 3, 2023
29 changes: 28 additions & 1 deletion Jenkinsfile
@@ -194,7 +194,7 @@ pipeline {
stage('Speech To Text Finetuning') {
steps {
sh 'python examples/asr/speech_to_text_finetune.py \
--config-path="conf" --config-name="speech_to_text_finetune" \
--config-path="conf/asr_finetune" --config-name="speech_to_text_finetune" \
model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
init_from_nemo_model=/home/TestData/asr/stt_en_fastconformer_transducer_large.nemo \
@@ -207,6 +207,33 @@
}
}

stage('Speech To Text HF Finetuning') {
steps {
sh 'python examples/asr/speech_to_text_finetune.py \
Collaborator review comment: Let's set all batch sizes to 1 for train and val.

--config-path="conf/asr_finetune" --config-name="speech_to_text_hf_finetune" \
~model.train_ds.hf_data_cfg \
model.train_ds.streaming=true \
+model.train_ds.hf_data_cfg.path="librispeech_asr" \
+model.train_ds.hf_data_cfg.name=null \
+model.train_ds.hf_data_cfg.split="test.clean" \
+model.train_ds.hf_data_cfg.streaming=true \
~model.validation_ds.hf_data_cfg \
model.validation_ds.streaming=true \
+model.validation_ds.hf_data_cfg.path="librispeech_asr" \
+model.validation_ds.hf_data_cfg.name=null \
+model.validation_ds.hf_data_cfg.split="test.clean" \
+model.validation_ds.hf_data_cfg.streaming=true \
~model.test_ds \
init_from_nemo_model=/home/TestData/asr/stt_en_fastconformer_transducer_large.nemo \
model.tokenizer.update_tokenizer=False \
trainer.devices=[1] \
trainer.accelerator="gpu" \
+trainer.fast_dev_run=True \
exp_manager.exp_dir=examples/asr/speech_finetuning_results'
sh 'rm -rf examples/asr/speech_finetuning_results'
}
}

// TODO: Please Fix Me
// Error locating target 'nemo.collections.asr.modules.wav2vec_modules.ConvFeatureEncoder', see chained exception above.
// stage('L2: Speech Pre-training - Wav2Vec') {
12 changes: 12 additions & 0 deletions docs/source/asr/configs.rst
@@ -1065,6 +1065,18 @@ The same script can be used to finetune CTC, RNNT or Hybrid models as well.
trainer.max_epochs=50 \
+init_from_nemo_model="<path to .nemo model file>" (or +init_from_pretrained_model="<name of pretrained checkpoint>")


Refer to <NeMo_repo>/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml for more details.

Finetune ASR Models using HuggingFace Datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Users can finetune NeMo ASR models directly on HuggingFace Datasets. The following config file can be used for this purpose:
`<NeMo_repo>/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml`

As mentioned earlier, users can update the tokenizer or use an existing one based on their requirements. If users want to create a new tokenizer
from HuggingFace Datasets, they can use the following script:
`<NeMo_repo>/scripts/tokenizers/get_hf_text_data.py`

Fine-tuning by changing model architecture and tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
189 changes: 189 additions & 0 deletions examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml
@@ -0,0 +1,189 @@
name: "Speech_To_Text_HF_Finetuning_using_HF_Datasets"

# use `init_from_nemo_model` or `init_from_pretrained_model` to initialize the model
# We do not currently support `init_from_ptl_ckpt` to create a single script for all types of models.
init_from_nemo_model: null # path to nemo model
init_from_pretrained_model: null # name of pretrained NeMo model, e.g., `stt_en_fastconformer_transducer_large`

model:
sample_rate: 16000
compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
log_prediction: true # enables logging sample predictions in the output during training
rnnt_reduction: 'mean_volume'
skip_nan_grad: false

# configs for huggingface load_dataset function
data_path: "librispeech_asr"
data_name: null # name for the specific dataset to load, e.g., 'en' for MCV datasets, but some datasets don't require this field.
streaming: false # set to True to use streaming mode, which starts training without waiting for the full download but makes each step in the first epoch slower. If True, specify trainer.max_steps instead of trainer.max_epochs.

# keys for audio, sample_rate, and transcription in the huggingface dataset; keys are separated by `.` for nested fields. See the example at the bottom of this file.
audio_key: "audio.array"
sample_rate_key: "audio.sampling_rate"
text_key: "text" # the key for groundtruth transcription, e.g., MCV usually uses "sentence" while some others use "text"

# simple text cleaning, by default converts all chars to lower-case and only keeps alpha-numeric chars.
normalize_text: true
symbols_to_keep: ["'"] # a list of symbols to keep during text cleaning.

train_ds:
manifest_filepath: "hugginface" # set to a not None value to avoid breaking existing code
streaming: ${model.streaming}
normalize_text: ${model.normalize_text}
symbols_to_keep: ${model.symbols_to_keep}
audio_key: ${model.audio_key}
sample_rate_key: ${model.sample_rate_key}
text_key: ${model.text_key}
hf_data_cfg: # hf_data_cfg can be a ListConfig or DictConfig. Params for each data are passed into huggingface load_dataset(). Add more params if needed
- path: ${model.data_path}
name: ${model.data_name}
split: 'train.clean.360'
streaming: ${model.streaming}
- path: ${model.data_path}
name: ${model.data_name}
split: 'train.clean.100'
streaming: ${model.streaming}
- path: ${model.data_path}
name: ${model.data_name}
split: 'train.other.500'
streaming: ${model.streaming}

sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
shuffle_n: 2048
num_workers: 8
pin_memory: true
use_start_end_token: false

validation_ds:
manifest_filepath: "hugginface" # set to a not None value to avoid breaking existing code
streaming: ${model.streaming}
normalize_text: ${model.normalize_text}
symbols_to_keep: ${model.symbols_to_keep}
audio_key: ${model.audio_key}
sample_rate_key: ${model.sample_rate_key}
text_key: ${model.text_key}
hf_data_cfg: # An example of using only one dataset
path: ${model.data_path}
name: ${model.data_name}
split: 'validation.other'
streaming: ${model.streaming}

sample_rate: ${model.sample_rate}
batch_size: 8
shuffle: false
shuffle_n: 2048
num_workers: 8
pin_memory: true
use_start_end_token: false

test_ds:
manifest_filepath: "hugginface" # set to a not None value to avoid breaking existing code
streaming: ${model.streaming}
normalize_text: ${model.normalize_text}
symbols_to_keep: ${model.symbols_to_keep}
audio_key: ${model.audio_key}
sample_rate_key: ${model.sample_rate_key}
text_key: ${model.text_key}
hf_data_cfg: # hf_data_cfg can be a ListConfig or DictConfig. Params for each data are passed into huggingface load_dataset(). Add more params if needed
- path: ${model.data_path}
name: ${model.data_name}
split: 'test.other'
streaming: ${model.streaming}
- path: ${model.data_path}
name: ${model.data_name}
split: 'test.clean'
streaming: ${model.streaming}

sample_rate: ${model.sample_rate}
batch_size: 8
shuffle: false
shuffle_n: 2048
num_workers: 8
pin_memory: true
use_start_end_token: false

char_labels: # use for char based models
update_labels: false
labels: null # example list config: [' ', 'a', 'b', 'c']

tokenizer: # use for spe/bpe based tokenizer models
update_tokenizer: false
dir: null # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

optim:
name: adamw
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 5000
warmup_ratio: null
min_lr: 5e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 100
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 32 # 16, 32, or bf16
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; set to 0 to disable
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training


exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_wer"
mode: "min"
save_top_k: 5
always_save_nemo: True # saves the checkpoints as nemo files along with PTL checkpoints
resume_if_exists: false
resume_ignore_no_checkpoint: false

create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null


# An example item in the HuggingFace `librispeech_asr` dataset:
# {'chapter_id': 141231,
# 'file': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac',
# 'audio': {
# 'path': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac',
# 'array': array([-0.00048828, -0.00018311, -0.00137329, ..., 0.00079346, 0.00091553, 0.00085449], dtype=float32),
# 'sampling_rate': 16000
# },
# 'id': '1272-141231-0000',
# 'speaker_id': 1272,
# 'text': 'A MAN SAID TO THE UNIVERSE SIR I EXIST'}
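
To show how the dotted keys above (e.g. `audio.array`) are meant to resolve against such a sample, here is a small illustrative sketch; the `resolve_key` helper is hypothetical and not part of NeMo:

from datasets import load_dataset

def resolve_key(sample, dotted_key):
    # Follow a dotted key such as "audio.array" through nested dict fields.
    value = sample
    for part in dotted_key.split("."):
        value = value[part]
    return value

# Streaming lets us inspect a single sample without downloading the whole split.
ds = load_dataset("librispeech_asr", split="test.clean", streaming=True)
sample = next(iter(ds))

audio = resolve_key(sample, "audio.array")        # waveform samples (float array)
sr = resolve_key(sample, "audio.sampling_rate")   # 16000 for LibriSpeech
text = resolve_key(sample, "text")                # ground-truth transcription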
2 changes: 1 addition & 1 deletion examples/asr/speech_to_text_finetune.py
@@ -185,7 +185,7 @@ def setup_dataloaders(asr_model, cfg):
return asr_model


@hydra_runner(config_path="conf", config_name="speech_to_text_finetune")
@hydra_runner(config_path="conf/asr_finetune", config_name="speech_to_text_finetune")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

14 changes: 14 additions & 0 deletions nemo/collections/asr/data/audio_to_text_dataset.py
@@ -25,6 +25,10 @@
from torch.utils.data import ChainDataset

from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
from nemo.collections.asr.data.huggingface.hf_audio_to_text_dataset import (
get_hf_audio_to_text_bpe_dataset,
get_hf_audio_to_text_char_dataset,
)
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.dataset import CodeSwitchedDataset, ConcatDataset
from nemo.utils import logging
@@ -598,6 +602,11 @@ def get_audio_to_text_char_dataset_from_config(
else:
augmentor = None

if 'hf_data_cfg' in config:
return get_hf_audio_to_text_char_dataset(
config=config, global_rank=global_rank, world_size=world_size, augmentor=augmentor
)

is_concat = config.get('is_concat', False)
if is_concat:
if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
@@ -722,6 +731,11 @@ def get_audio_to_text_bpe_dataset_from_config(
else:
augmentor = None

if 'hf_data_cfg' in config:
return get_hf_audio_to_text_bpe_dataset(
config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer, augmentor=augmentor
)

is_concat = config.get('is_concat', False)
if is_concat:
if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
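
Each entry in `hf_data_cfg` is effectively the set of keyword arguments for one `datasets.load_dataset()` call; with a ListConfig, several splits feed a single dataloader. A rough, illustrative sketch of that mapping (the actual combination logic lives inside the new NeMo HuggingFace dataset classes, not in this helper):

from datasets import concatenate_datasets, interleave_datasets, load_dataset

def load_hf_sources(hf_data_cfg):
    # hf_data_cfg may be a single dict or a list of dicts, as in the YAML config above.
    entries = hf_data_cfg if isinstance(hf_data_cfg, (list, tuple)) else [hf_data_cfg]
    loaded = [load_dataset(**dict(entry)) for entry in entries]
    if len(loaded) == 1:
        return loaded[0]
    # Illustrative only: streaming (iterable) datasets are interleaved,
    # map-style datasets are concatenated.
    if any(dict(entry).get("streaming", False) for entry in entries):
        return interleave_datasets(loaded)
    return concatenate_datasets(loaded)

# Mirrors the train_ds block of the new YAML config:
train_source = load_hf_sources([
    {"path": "librispeech_asr", "name": None, "split": "train.clean.100", "streaming": True},
    {"path": "librispeech_asr", "name": None, "split": "train.clean.360", "streaming": True},
])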
13 changes: 13 additions & 0 deletions nemo/collections/asr/data/huggingface/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.