diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py index d32f50371e68..294965914d00 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py @@ -97,6 +97,22 @@ class ModelArguments: freeze_feature_encoder: bool = field( default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} ) + freeze_encoder: bool = field( + default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."} + ) + forced_decoder_ids: List[List[int]] = field( + default=None, + metadata={ + "help": ( + "A list of pairs of integers which indicates a mapping from generation indices to token indices " + "that will be forced before sampling. For example, [[0, 123]] means the first generated token " + "will always be a token of index 123." + ) + }, + ) + suppress_tokens: List[int] = field( + default=None, metadata={"help": "A list of tokens that will be suppressed at generation."} + ) @dataclass @@ -187,6 +203,19 @@ class DataTrainingArguments: default=True, metadata={"help": "Whether the target text should be lower cased."}, ) + language: str = field( + default=None, + metadata={ + "help": ( + "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning " + "only. For English speech recognition, it should be set to `None`." + ) + }, + ) + task: str = field( + default="transcribe", + metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."}, + ) @dataclass @@ -194,7 +223,7 @@ class DataCollatorSpeechSeq2SeqWithPadding: """ Data collator that will dynamically pad the inputs received. Args: - processor ([`Wav2Vec2Processor`]) + processor ([`WhisperProcessor`]) The processor used for processing the data. decoder_start_token_id (`int`) The begin-of-sentence of the decoder. @@ -206,7 +235,8 @@ class DataCollatorSpeechSeq2SeqWithPadding: def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: # split inputs and labels since they have to be of different lengths and need # different padding methods - input_features = [{"input_values": feature["input_values"]} for feature in features] + model_input_name = self.processor.model_input_names[0] + input_features = [{model_input_name: feature[model_input_name]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") @@ -333,6 +363,8 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) + config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens}) + feature_extractor = AutoFeatureExtractor.from_pretrained( model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, @@ -360,6 +392,14 @@ def main(): if model_args.freeze_feature_encoder: model.freeze_feature_encoder() + if model_args.freeze_encoder: + model.freeze_encoder() + model.model.encoder.gradient_checkpointing = False + + if data_args.language is not None: + # We only need to set the task id when the language is specified (i.e. in a multilingual setting) + tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task) + # 6. Resample speech dataset if necessary dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate if dataset_sampling_rate != feature_extractor.sampling_rate: @@ -388,8 +428,8 @@ def prepare_dataset(batch): sample = batch[audio_column_name] inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) # process audio length - batch[model_input_name] = inputs.input_values[0] - batch["input_length"] = len(batch["input_values"]) + batch[model_input_name] = inputs.get(model_input_name)[0] + batch["input_length"] = len(sample["array"]) # process targets input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] @@ -452,7 +492,8 @@ def compute_metrics(pred): # 10. Define data collator data_collator = DataCollatorSpeechSeq2SeqWithPadding( - processor=processor, decoder_start_token_id=model.config.decoder_start_token_id + processor=processor, + decoder_start_token_id=model.config.decoder_start_token_id, ) # 11. Initialize Trainer @@ -492,7 +533,9 @@ def compute_metrics(pred): if training_args.do_eval: logger.info("*** Evaluate ***") metrics = trainer.evaluate( - metric_key_prefix="eval", max_length=model.config.max_length, num_beams=model.config.num_beams + metric_key_prefix="eval", + max_length=training_args.generation_max_length, + num_beams=training_args.generation_num_beams, ) max_eval_samples = ( data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])