diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index d9dd850a5c..06795a2538 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -16,6 +16,22 @@ limitations under the License. # Automatic Speech Recognition Examples +## Table of Contents + +- [Automatic Speech Recognition with CTC](#connectionist-temporal-classification) + - [Single GPU example](#single-gpu-ctc) + - [Multi GPU example](#multi-gpu-ctc) + - [Examples](#examples-ctc) + - [TIMIT](#timit-ctc) + - [Librispeech](#librispeech-ctc) + - [Common Voice](#common-voice-ctc) + - [Multilingual Librispeech](#multilingual-librispeech-ctc) +- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence) + - [Whisper Model](#whisper-model) + - [Fine tuning](#single-hpu-whisper-training-with-seq2seq) + - [Inference](#single-gpu-seq2seq-inference) + + ## Connectionist Temporal Classification The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/optimum-habana/tree/main/examples/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForCTC) for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset. @@ -187,3 +203,176 @@ python run_speech_recognition_ctc.py \ --gaudi_config_name="Habana/wav2vec2" \ --bf16 ``` +## Sequence to Sequence + +The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/optimum-habana/examples/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Whisper Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/whisper#whisper) for automatic speech +recognition on one of the well known speech recognition datasets similar to shown below or a custom dataset. Examples of two datasets using the Whisper model from OpenAI are included below. + +### Whisper Model +We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model weights, feature extractor and tokenizer. We simply have to specify our fine-tuning dataset and training hyperparameters. +```bash +# Set model dir and datasets dir as appropriate +export MODEL_DIR=$HOME/huggingface/hub +export DATASETS_DIR=$HOME/huggingface/datasets + +# Run this from the optimum-habana repository +cd examples/speech-recognition +# install optimum-habana +pip install -r requirements.txt +``` + +#### Single HPU Whisper Fine tuning with Seq2Seq +The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using a single GPU device in half-precision: +```bash +python run_speech_recognition_seq2seq.py \ + --model_name_or_path="openai/whisper-small" \ + --dataset_name="mozilla-foundation/common_voice_11_0" \ + --dataset_config_name="hi" \ + --language="hindi" \ + --train_split_name="train+validation" \ + --eval_split_name="test" \ + --gaudi_config_name="gaudi_config.json" \ + --max_steps="5000" \ + --output_dir="./results/whisper-small-hi" \ + --per_device_train_batch_size="16" \ + --gradient_accumulation_steps="2" \ + --per_device_eval_batch_size="16" \ + --logging_steps="25" \ + --learning_rate="1e-5" \ + --warmup_steps="500" \ + --evaluation_strategy="steps" \ + --eval_steps="1000" \ + --save_strategy="steps" \ + --save_steps="1000" \ + --generation_max_length="225" \ + --preprocessing_num_workers="16" \ + --length_column_name="input_length" \ + --max_duration_in_seconds="30" \ + --text_column_name="sentence" \ + --freeze_feature_encoder="False" \ + --gradient_checkpointing \ + --group_by_length \ + --bf16 \ + --overwrite_output_dir \ + --do_train \ + --do_eval \ + --predict_with_generate \ + --use_habana \ + --use_hpu_graphs_for_inference +``` + +If training on a different language, you should be sure to change the `language` and `dataset_config_name` arguments. + +#### Multi HPU Whisper Training with Seq2Seq +The following example shows how to fine-tune the [Whisper large](https://huggingface.co/openai/whisper-large) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 8 HPU devices in half-precision: +```bash +python ../gaudi_spawn.py \ + --world_size 8 --use_mpi run_speech_recognition_seq2seq.py \ + --model_name_or_path="openai/whisper-large" \ + --dataset_name="mozilla-foundation/common_voice_11_0" \ + --dataset_config_name="hi" \ + --language="hindi" \ + --train_split_name="train+validation" \ + --eval_split_name="test" \ + --gaudi_config_name="gaudi_config.json" \ + --max_steps="5000" \ + --output_dir="./results/whisper-large-hi" \ + --per_device_train_batch_size="16" \ + --gradient_accumulation_steps="2" \ + --per_device_eval_batch_size="16" \ + --logging_steps="25" \ + --learning_rate="1e-5" \ + --warmup_steps="500" \ + --evaluation_strategy="steps" \ + --eval_steps="1000" \ + --save_strategy="steps" \ + --save_steps="1000" \ + --generation_max_length="225" \ + --preprocessing_num_workers="16" \ + --length_column_name="input_length" \ + --max_duration_in_seconds="30" \ + --text_column_name="sentence" \ + --freeze_feature_encoder="False" \ + --gradient_checkpointing \ + --group_by_length \ + --bf16 \ + --overwrite_output_dir \ + --do_train \ + --do_eval \ + --predict_with_generate \ + --use_habana \ + --use_hpu_graphs_for_inference +``` + +#### Single GPU Seq2Seq Inference + +The following example shows how to do inference with the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [librispeech_asr_demo](https://huggingface.co/datasets/librispeech_asr_demo) using 1 HPU devices in half-precision: + +```bash +python run_speech_recognition_seq2seq.py \ + --model_name_or_path="openai/whisper-small" \ + --dataset_name="hf-internal-testing/librispeech_asr_demo" \ + --gaudi_config_name="gaudi_config.json" \ + --dataset_config_name="clean" \ + --eval_split_name="validation" \ + --max_steps="5000" \ + --output_dir="./results/whisper-small-clean" \ + --gradient_accumulation_steps="2" \ + --per_device_eval_batch_size="16" \ + --logging_steps="25" \ + --learning_rate="1e-5" \ + --warmup_steps="500" \ + --evaluation_strategy="steps" \ + --eval_steps="1000" \ + --save_strategy="steps" \ + --save_steps="1000" \ + --generation_max_length="225" \ + --preprocessing_num_workers="16" \ + --length_column_name="input_length" \ + --max_duration_in_seconds="30" \ + --freeze_feature_encoder="False" \ + --gradient_checkpointing \ + --group_by_length \ + --bf16 \ + --overwrite_output_dir \ + --do_eval \ + --predict_with_generate \ + --use_habana \ + --use_hpu_graphs_for_inference +``` +The following example shows how to do inference with the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 1 HPU devices in half-precision: + +```bash +python run_speech_recognition_seq2seq.py \ + --model_name_or_path="openai/whisper-small" \ + --dataset_name="mozilla-foundation/common_voice_11_0" \ + --dataset_config_name="hi" \ + --language="hindi" \ + --eval_split_name="test" \ + --gaudi_config_name="gaudi_config.json" \ + --max_steps="5000" \ + --output_dir="./results/whisper-small-clean" \ + --gradient_accumulation_steps="2" \ + --per_device_eval_batch_size="16" \ + --logging_steps="25" \ + --learning_rate="1e-5" \ + --warmup_steps="500" \ + --evaluation_strategy="steps" \ + --eval_steps="1000" \ + --save_strategy="steps" \ + --save_steps="1000" \ + --generation_max_length="225" \ + --preprocessing_num_workers="16" \ + --length_column_name="input_length" \ + --max_duration_in_seconds="30" \ + --text_column_name="sentence" \ + --freeze_feature_encoder="False" \ + --gradient_checkpointing \ + --group_by_length \ + --bf16 \ + --overwrite_output_dir \ + --do_eval \ + --predict_with_generate \ + --use_habana \ + --use_hpu_graphs_for_inference +``` diff --git a/examples/speech-recognition/gaudi_config.json b/examples/speech-recognition/gaudi_config.json new file mode 100644 index 0000000000..99480fb15b --- /dev/null +++ b/examples/speech-recognition/gaudi_config.json @@ -0,0 +1,3 @@ +{ + "use_torch_autocast": true, +} diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py new file mode 100755 index 0000000000..97f846588a --- /dev/null +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence speech recognition. +""" +# You can also adapt this script on your own sequence to sequence speech +# recognition task. Pointers for this are left as comments. + +import logging +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import datasets +import evaluate +import torch +from datasets import DatasetDict, load_dataset + +import transformers +from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoModelForSpeechSeq2Seq, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + +from optimum.habana import GaudiConfig, GaudiSeq2SeqTrainer, GaudiSeq2SeqTrainingArguments +from optimum.habana.utils import set_seed + +try: + from optimum.habana.utils import check_optimum_habana_min_version +except ImportError: + + def check_optimum_habana_min_version(*a, **b): + return () + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.34.0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + token: str = field( + default=None, + metadata={ + "help": ( + "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " + "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." + ) + }, + ) + use_auth_token: bool = field( + default=None, + metadata={ + "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead." + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" + "should only be set to `True` for repositories you trust and in which you have read the code, as it will " + "execute code present on the Hub on your local machine." + ) + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + freeze_encoder: bool = field( + default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."} + ) + forced_decoder_ids: List[List[int]] = field( + default=None, + metadata={ + "help": ( + "A list of pairs of integers which indicates a mapping from generation indices to token indices " + "that will be forced before sampling. For example, [[0, 123]] means the first generated token " + "will always be a token of index 123." + ) + }, + ) + suppress_tokens: List[int] = field( + default=None, metadata={"help": "A list of tokens that will be suppressed at generation."} + ) + apply_spec_augment: bool = field( + default=False, + metadata={ + "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": ( + "Truncate audio files that are longer than `max_duration_in_seconds` seconds to" + " 'max_duration_in_seconds`" + ) + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": ( + "Whether to only do data preprocessing and skip training. This is especially useful when data" + " preprocessing errors out in distributed training due to timeout. In this case, one should run the" + " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets" + " can consequently be loaded in distributed training" + ) + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="test", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + language: str = field( + default=None, + metadata={ + "help": ( + "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning " + "only. For English speech recognition, it should be set to `None`." + ) + }, + ) + task: str = field( + default="transcribe", + metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."}, + ) + + +@dataclass +class DataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`WhisperProcessor`]) + The processor used for processing the data. + decoder_start_token_id (`int`) + The begin-of-sentence of the decoder. + forward_attention_mask (`bool`) + Whether to return attention_mask. + """ + + processor: Any + decoder_start_token_id: int + forward_attention_mask: bool + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + model_input_name = self.processor.model_input_names[0] + input_features = [{model_input_name: feature[model_input_name]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") + + if self.forward_attention_mask: + batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features]) + + labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") + + # replace padding with -100 to ignore loss correctly + labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + # if bos token is appended in previous tokenization step, + # cut bos token here as it's append later anyways + if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + + return batch + + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiSeq2SeqTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if model_args.use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.", + FutureWarning, + ) + if model_args.token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + model_args.token = model_args.use_auth_token + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args) + + # 2. Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + gaudi_config = GaudiConfig.from_pretrained( + training_args.gaudi_config_name, + cache_dir=model_args.cache_dir, + # use_auth_token=True if data_args.use_auth_token else None, + use_auth_token=False + ) + + # Log on each process the small summary: + mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, " + + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, " + + f"mixed-precision training: {mixed_precision}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + logger.info("Training/evaluation parameters %s", training_args) + + # 3. Detecting last checkpoint and eventually continue from last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # 4. Load dataset + raw_datasets = DatasetDict() + + if training_args.do_train: + raw_datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.train_split_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + + if training_args.do_eval: + raw_datasets["eval"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.eval_split_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + + config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens}) + + # SpecAugment for whisper models + if getattr(config, "model_type", None) == "whisper": + config.update({"apply_spec_augment": model_args.apply_spec_augment}) + + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + if model_args.freeze_feature_encoder: + model.freeze_feature_encoder() + + if model_args.freeze_encoder: + model.freeze_encoder() + model.model.encoder.gradient_checkpointing = False + + if data_args.language is not None: + # We only need to set the task id when the language is specified (i.e. in a multilingual setting) + tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task) + + # 6. Resample speech dataset if necessary + dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate + if dataset_sampling_rate != feature_extractor.sampling_rate: + logger.warning( + f"The dataset sampling rate ({dataset_sampling_rate}) is different from the feature extractor one" + f" ({feature_extractor.sampling_rate}).Data resampling should be done." + ) + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. + # We need to read the audio files as arrays and tokenize the targets. + max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate + min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis + forward_attention_mask = ( + getattr(config, "model_type", None) == "whisper" + and getattr(config, "apply_spec_augment", False) + and getattr(config, "mask_time_prob", 0) > 0 + ) + + if data_args.max_train_samples is not None: + raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) + + if data_args.max_eval_samples is not None: + raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples)) + + def prepare_dataset(batch): + # process audio + sample = batch[audio_column_name] + inputs = feature_extractor( + sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask + ) + # process audio length + batch[model_input_name] = inputs.get(model_input_name)[0] + batch["input_length"] = len(sample["array"]) + if forward_attention_mask: + batch["attention_mask"] = inputs.get("attention_mask")[0] + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + batch["labels"] = tokenizer(input_str).input_ids + return batch + + with training_args.main_process_first(desc="dataset map pre-processing"): + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=data_args.preprocessing_num_workers, + desc="preprocess train dataset", + ) + + # filter data that is shorter than min_input_length or longer than + # max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. + # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metric + metric = evaluate.load("wer") + + def compute_metrics(pred): + pred_ids = pred.predictions + + pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id + + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) + + wer = metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + # 9. Create a single speech processor + # make sure all processes wait until data is saved + with training_args.main_process_first(): + # only the main process saves them + if is_main_process(training_args.local_rank): + # save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + # 10. Define data collator + data_collator = DataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + decoder_start_token_id=model.config.decoder_start_token_id, + forward_attention_mask=forward_attention_mask, + ) + + # 11. Initialize Trainer + trainer = GaudiSeq2SeqTrainer( + model=model, + gaudi_config=gaudi_config, + args=training_args, + train_dataset=vectorized_datasets["train"] if training_args.do_train else None, + eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, + tokenizer=feature_extractor, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + # 12. Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the feature extractor too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(vectorized_datasets["train"]) + ) + metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"])) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # 13. Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate( + metric_key_prefix="eval", + max_length=training_args.generation_max_length, + num_beams=training_args.generation_num_beams, + ) + max_eval_samples = ( + data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"]) + ) + metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"])) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # 14. Write Training Stats + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "automatic-speech-recognition"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + return results + + +if __name__ == "__main__": + main()