diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml
index 05d77fe9..79df00b3 100644
--- a/.github/workflows/test_models.yml
+++ b/.github/workflows/test_models.yml
@@ -40,6 +40,7 @@ jobs:
           - swin
           - t5
           - vit
+          - whisper
         executorch-version: ['0.4.0', '0.6.0rc', 'nightly']
         python-version: ['3.10', '3.11', '3.12']
         os: [macos-15]
diff --git a/README.md b/README.md
index 6819c6ae..76ae69d6 100644
--- a/README.md
+++ b/README.md
@@ -162,7 +162,8 @@ We currently support a wide range of popular transformer models, including encod
 🚀 Coming more soon...
 
 ### Audio Models
-🔊 Coming later
+#### Encoder-decoder models
+- [Whisper](https://huggingface.co/openai/whisper-tiny): OpenAI's `Whisper` and its variants
 
 *📌 Note: This list is continuously expanding. As we continue to expand support, more models will be added.*
diff --git a/optimum/executorch/__init__.py b/optimum/executorch/__init__.py
index 0b3eba1c..07b72a6b 100644
--- a/optimum/executorch/__init__.py
+++ b/optimum/executorch/__init__.py
@@ -23,6 +23,7 @@
         "ExecuTorchModelForImageClassification",
         "ExecuTorchModelForMaskedLM",
         "ExecuTorchModelForSeq2SeqLM",
+        "ExecuTorchModelForSpeechSeq2Seq",
     ],
 }
@@ -32,6 +33,7 @@
         ExecuTorchModelForImageClassification,
         ExecuTorchModelForMaskedLM,
         ExecuTorchModelForSeq2SeqLM,
+        ExecuTorchModelForSpeechSeq2Seq,
     )
 else:
     import sys
diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index 629cc176..e156ef65 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -29,6 +29,7 @@
     AutoModelForImageClassification,
     AutoModelForMaskedLM,
     AutoModelForSeq2SeqLM,
+    AutoModelForSpeechSeq2Seq,
     PretrainedConfig,
     PreTrainedTokenizer,
     add_start_docstrings,
@@ -871,3 +872,183 @@ def forward(
 
     def generate(self):
         raise NotImplementedError
+
+
+class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase):
+    """
+    A speech seq2seq ExecuTorch model for inference using the ExecuTorch Runtime.
+
+    This class provides an interface for loading, running, and generating outputs from a speech seq2seq model
+    optimized for the ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models
+    compatible with the ExecuTorch runtime.
+
+    Attributes:
+        auto_model_class (`Type`):
+            Associated Transformers class, `AutoModelForSpeechSeq2Seq`.
+        model (`ExecuTorchModule`):
+            The loaded ExecuTorch model.
+        use_kv_cache (`bool`):
+            Whether key-value caching is enabled. For performance reasons, the exported model is
+            optimized to use a static cache.
+        max_cache_size (`int`):
+            Maximum sequence length supported by the cache.
+        max_batch_size (`int`):
+            Maximum supported batch size.
+        dtype (`str`):
+            Data type of the model parameters.
+        bos_token_id (`int`):
+            Beginning-of-sequence token ID.
+        eos_token_id (`int`):
+            End-of-sequence token ID.
+        vocab_size (`int`):
+            Size of the model vocabulary.
+ """ + + auto_model_class = AutoModelForSpeechSeq2Seq + + def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"): + super().__init__(models=models, config=config) + if not hasattr(self, "encoder"): + raise AttributeError("Expected attribute 'encoder' not found in the instance.") + if not hasattr(self, "decoder"): + raise AttributeError("Expected attribute 'decoder' not found in the instance.") + metadata = self.decoder.method_names() + if "use_kv_cache" in metadata: + self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.decoder.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.decoder.run_method("get_bos_id")[0] + if "get_eos_id" in metadata: + self.eos_token_id = self.decoder.run_method("get_eos_id")[0] + if "get_vocab_size" in metadata: + self.vocab_size = self.decoder.run_method("get_vocab_size")[0] + if "max_hidden_seq_length" in metadata: + self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0] + if "decoder_start_token_id" in metadata: + self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0] + + def forward( + self, + input_features: torch.Tensor, + decoder_input_ids: torch.Tensor, + cache_position: torch.Tensor, + encoder_outputs: Optional[torch.Tensor] = None, + ): + is_first_prediction = encoder_outputs is None + self.stats.on_model_execution_start() + if is_first_prediction: + encoder_outputs = self.encoder.forward((input_features,))[0] + self.stats.on_prompt_eval_end() + + result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) + self.stats.on_model_execution_end() + return result + + def generate( + self, + input_features: torch.Tensor, + echo: bool = False, + pos_base: int = 0, + max_seq_len: Optional[int] = None, + ) -> List[int]: + """ + Generate tokens from a prompt using the ExecuTorch model. + + Args: + input_features (List[int]): + Log-mel spectrogram for 30-second audio chunk. Can be obtained using the WhisperProcessor. Should be of shape [1, 80, 3000] or + [1, 128, 3000]. For details, check out the processor config. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `False`. + pos_base (`int`, *optional*): + Base position for the prompt tokens. Defaults to 0. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + + Returns: + List[int]: List of generated token IDs. + """ + self.device = torch.device("cpu") + if max_seq_len is None: + # Default to max_cache_size if max_seq_len is not specified + max_seq_len = self.max_cache_size + elif max_seq_len > self.max_cache_size: + logging.warning( + f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." 
+            )
+            max_seq_len = self.max_cache_size
+
+        if not hasattr(self, "decoder_start_token_id"):
+            raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.")
+        decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long)
+        log_mel = input_features
+        encoder_outputs = None
+        generated_ids = []
+        first_token_generated = False
+
+        # Generate tokens one by one
+        for i in range(max_seq_len - 1):
+            # Run decoder for next token prediction
+            cache_position = torch.tensor([i], dtype=torch.long)
+            self.stats.on_sampling_begin()
+            logits, encoder_outputs = self.forward(log_mel, decoder_input_ids, cache_position, encoder_outputs)
+            self.stats.on_sampling_end()
+            if not first_token_generated:
+                self.stats.on_first_token()
+                first_token_generated = True
+            # Get next token
+            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
+            generated_ids.append(next_token)
+            self.stats.set_num_generated_tokens(len(generated_ids) - 1)  # Don't count decoder_start_token
+
+            # Update input for next iteration
+            decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)
+
+            # Check if EOS token
+            if next_token == self.eos_token_id:
+                break
+
+        return generated_ids
+
+    def transcribe(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        input_features: torch.Tensor,
+        echo: bool = True,
+        max_seq_len: Optional[int] = None,
+    ):
+        """
+        Transcribe audio to text from a log-mel spectrogram using the ExecuTorch model.
+
+        Args:
+            tokenizer (`PreTrainedTokenizer`):
+                The tokenizer used to encode and decode the prompt and output.
+            input_features (`torch.Tensor`):
+                Log-mel spectrogram for a 30-second audio chunk, as produced by the `WhisperProcessor`. Should be of
+                shape [1, 80, 3000] or [1, 128, 3000]. For details, check out the processor config.
+            echo (`bool`, *optional*):
+                Whether to include prompt tokens in the generated output. Defaults to `True`.
+            max_seq_len (`int`, *optional*):
+                Maximum sequence length for the generated output.
+                Defaults to None and uses the model's `max_cache_size` attribute.
+                Will be truncated to the maximum cache size if larger than `max_cache_size`.
+ """ + self.tokenizer = tokenizer + + self.stats.reset() + self.stats.on_inference_start() + generated_tokens = self.generate( + input_features=input_features, + echo=echo, + max_seq_len=max_seq_len, + ) + self.stats.on_inference_end() + self.stats.print_report() + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 5b8c37b6..64bbe7dc 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -17,7 +17,13 @@ import torch from torch.export import ExportedProgram from torch.nn.attention import SDPBackend -from transformers import PreTrainedModel, StaticCache +from transformers import ( + AutoProcessor, + PreTrainedModel, + StaticCache, + T5ForConditionalGeneration, + WhisperForConditionalGeneration, +) from transformers.generation.configuration_utils import GenerationConfig from optimum.utils.import_utils import is_transformers_version @@ -153,7 +159,7 @@ def __init__(self, encoder_model): self.config = encoder_model.config def forward(self, input_ids): - return self.encoder(input_ids=input_ids).last_hidden_state + return self.encoder(input_ids).last_hidden_state class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module): @@ -168,7 +174,10 @@ def __init__(self, model, max_static_cache_length, batch_size): # Get the decoder component self.decoder = model.get_decoder() - self.lm_head = model.lm_head + if isinstance(model, WhisperForConditionalGeneration): + self.proj_out = model.proj_out + else: + self.proj_out = model.lm_head self.config = model.config # Initialize static cache @@ -195,10 +204,9 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): cache_position=cache_position, ) - # Apply language model head - lm_logits = self.lm_head(outputs[0]) - - return lm_logits + # Apply linear projection (lm head) to obtain logits + logits = self.proj_out(outputs[0]) + return logits class Seq2SeqLMExportableModule(torch.nn.Module): @@ -225,6 +233,15 @@ def __init__( "max_cache_len": max_cache_length, }, ) + if isinstance(self.full_model, WhisperForConditionalGeneration): + self._processor = AutoProcessor.from_pretrained(model.config._name_or_path) + self._expected_encoder_input_shape = torch.Size( + ( + 1, + self._processor.feature_extractor.feature_size, + self._processor.feature_extractor.nb_max_frames, + ) + ) additional_configs = {} additional_configs["max_hidden_seq_length"] = max_hidden_seq_length # Metadata to be recorded in the pte model file @@ -240,14 +257,25 @@ def _export_encoder(self, encoder_input_ids): wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() # Define dynamic sequence length for encoder - seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) + if isinstance(self.full_model, WhisperForConditionalGeneration): + assert ( + encoder_input_ids.shape == self._expected_encoder_input_shape + ), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}. 
+            For more information, please refer to the Whisper preprocessor config."""
+            dynamic_shapes = None
+        elif isinstance(self.full_model, T5ForConditionalGeneration):
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+            dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}}
+        else:
+            raise ValueError(
+                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export."
+            )
 
         # Export the encoder
         with torch.no_grad():
             exported_encoder = torch.export.export(
-                wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
+                wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, strict=True
             )
-
         return exported_encoder
 
     def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
@@ -261,19 +289,27 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
             .eval()
         )
 
-        # Define dynamic dimension for encoder output sequence length
-        encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+        if isinstance(self.full_model, WhisperForConditionalGeneration):
+            dynamic_shapes = None
+        elif isinstance(self.full_model, T5ForConditionalGeneration):
+            # Define dynamic dimension for encoder output sequence length
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+            dynamic_shapes = {
+                "decoder_input_ids": None,
+                "encoder_hidden_states": {1: encoder_seq_len_dim},
+                "cache_position": None,
+            }
+        else:
+            raise ValueError(
+                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
+            )
 
         # Export the decoder
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
             exported_decoder = torch.export.export(
                 wrapped_decoder,
                 (decoder_input_ids, encoder_hidden_states, cache_position),
-                dynamic_shapes={
-                    "decoder_input_ids": None,
-                    "encoder_hidden_states": {1: encoder_seq_len_dim},
-                    "cache_position": None,
-                },
+                dynamic_shapes=dynamic_shapes,
                 strict=True,
             )
 
@@ -286,21 +322,26 @@ def export(
         encoder_hidden_states=None,
         cache_position=None,
     ) -> Dict[str, ExportedProgram]:
-        example_encoder_input_ids = (
-            encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
-        )
+        if encoder_input_ids is None:
+            if isinstance(self.full_model, WhisperForConditionalGeneration):
+                example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
+            else:
+                example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
+        else:
+            example_encoder_input_ids = encoder_input_ids
+
+        self.exported_encoder = self._export_encoder(example_encoder_input_ids)
+
+        if encoder_hidden_states is None:
+            example_encoder_hidden_states = self.exported_encoder.module()(example_encoder_input_ids)
+        else:
+            example_encoder_hidden_states = encoder_hidden_states
+
         example_decoder_input_ids = (
             decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
-        )  # Start token
-        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
-        example_encoder_hidden_states = (
-            encoder_hidden_states
-            if encoder_hidden_states is not None
-            else torch.zeros(
-                (self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
-            )
         )
-        self.exported_encoder = self._export_encoder(example_encoder_input_ids)
+        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
+
         self.exported_decoder = self._export_decoder(
             example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
         )
diff --git a/optimum/exporters/executorch/tasks/asr.py b/optimum/exporters/executorch/tasks/asr.py
new file mode 100644
index 00000000..bc20bdc1
--- /dev/null
+++ b/optimum/exporters/executorch/tasks/asr.py
@@ -0,0 +1,58 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import AutoModelForSpeechSeq2Seq
+
+from ..integrations import Seq2SeqLMExportableModule
+from ..task_registry import register_task
+
+
+# NOTE: It's important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py.
+# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier.
+@register_task("automatic-speech-recognition")
+def load_seq2seq_speech_model(model_name_or_path: str, **kwargs) -> Seq2SeqLMExportableModule:
+    """
+    Loads a model for speech seq2seq and registers it under the task
+    'automatic-speech-recognition' using Hugging Face's `AutoModelForSpeechSeq2Seq`.
+
+    Args:
+        model_name_or_path (str):
+            Model ID on huggingface.co or path on disk to the model repository to export. For example:
+            `model_name_or_path="openai/whisper-tiny"` or `model_name_or_path="/path/to/model_folder"`
+        **kwargs:
+            Additional configuration options for the model:
+                - dtype (str, optional):
+                    Data type for model weights (default: "float32").
+                    Options include "float16" and "bfloat16".
+                - max_hidden_seq_length (int, optional):
+                    Maximum hidden sequence length (default: 4096).
+                - max_cache_length (int, optional):
+                    Maximum sequence length for generation (default: 1024).
+
+    Returns:
+        Seq2SeqLMExportableModule:
+            An instance of `Seq2SeqLMExportableModule` for exporting and lowering to ExecuTorch.
+    """
+    device = "cpu"
+    batch_size = 1
+    max_hidden_seq_length = kwargs.get("max_hidden_seq_length", 4096)
+    max_cache_length = kwargs.get("max_cache_length", 1024)
+
+    full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path).to(device).eval()
+    return Seq2SeqLMExportableModule(
+        full_model,
+        batch_size=batch_size,
+        max_hidden_seq_length=max_hidden_seq_length,
+        max_cache_length=max_cache_length,
+    )
diff --git a/setup.py b/setup.py
index f990b3cc..ce24db7a 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,9 @@
     "pytest",
    "safetensors",
    "sentencepiece",
+    "numba!=0.58.0",  # Due to the bug https://github.com/numba/numba/issues/9209
+    "librosa",
+    "soundfile",
 ]
diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py
new file mode 100644
index 00000000..f6383823
--- /dev/null
+++ b/tests/models/test_modeling_whisper.py
@@ -0,0 +1,78 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import subprocess
+import tempfile
+import unittest
+
+import pytest
+from datasets import load_dataset
+from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from transformers import AutoProcessor, AutoTokenizer
+from transformers.testing_utils import slow
+
+from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq
+
+
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @slow
+    @pytest.mark.run_slow
+    def test_whisper_export_to_executorch(self):
+        model_id = "openai/whisper-tiny"
+        task = "automatic-speech-recognition"
+        recipe = "xnnpack"
+        with tempfile.TemporaryDirectory() as tempdir:
+            subprocess.run(
+                f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch",
+                shell=True,
+                check=True,
+            )
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/encoder.pte"))
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/decoder.pte"))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_whisper_transcription(self):
+        model_id = "openai/whisper-tiny"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(model_id, recipe="xnnpack")
+        processor = AutoProcessor.from_pretrained(model_id)
+
+        self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq)
+        self.assertTrue(hasattr(model, "encoder"))
+        self.assertIsInstance(model.encoder, ExecuTorchModule)
+        self.assertTrue(hasattr(model, "decoder"))
+        self.assertIsInstance(model.decoder, ExecuTorchModule)
+
+        dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
+        sample = dataset[0]["audio"]
+
+        input_features = processor(
+            sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
+        ).input_features
+        # The current implementation of the transcribe method accepts up to 30 seconds of audio, so the audio is trimmed here.
+        input_features_trimmed = input_features[:, :, :3000].contiguous()
+
+        generated_transcription = model.transcribe(tokenizer, input_features_trimmed)
+        expected_text = " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all, and can discover that."
+        logging.info(
+            f"\nExpected transcription:\n\t{expected_text}\nGenerated transcription:\n\t{generated_transcription}"
+        )
+        self.assertEqual(generated_transcription, expected_text)
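A minimal end-to-end usage sketch of the runtime class introduced in this patch, based on the test above. The audio file path is a placeholder, and loading/resampling via librosa is an assumption motivated by the librosa/soundfile test dependencies added in setup.py; everything else (the `recipe="xnnpack"` argument and `transcribe` signature) comes from the diff.

# Usage sketch (illustrative, not part of the patch).
import librosa
from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny"
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Exports the model with the XNNPACK recipe and loads the resulting encoder/decoder programs.
model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(model_id, recipe="xnnpack")

# Whisper expects 16 kHz audio; a single 30-second chunk of whisper-tiny features has shape [1, 80, 3000].
# "sample.wav" is a placeholder path.
waveform, _ = librosa.load("sample.wav", sr=16000)
input_features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features

print(model.transcribe(tokenizer, input_features))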
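For the export path, a sketch of driving the new `automatic-speech-recognition` task loader directly from Python, under the assumption that `export()` can be called with its default arguments (all example inputs default to `None` in the signature shown above) and returns the `Dict[str, ExportedProgram]` declared in its annotation; the dict keys are not asserted here since they are not shown in this diff.

# Export sketch (illustrative, not part of the patch).
from optimum.exporters.executorch.tasks.asr import load_seq2seq_speech_model

# Builds a Seq2SeqLMExportableModule wrapping Whisper with a static decoder cache.
module = load_seq2seq_speech_model("openai/whisper-tiny", max_cache_length=1024)

# For Whisper, the encoder input is fixed to the processor's [1, feature_size, nb_max_frames] shape,
# so no dynamic shapes are used during torch.export.
exported_programs = module.export()
for name, program in exported_programs.items():
    print(name, type(program))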