From c5a899ad0f3c0715b76e87a4921298e4c7f57a18 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Fri, 4 Apr 2025 12:12:27 +0200 Subject: [PATCH 01/19] feat: add ASR task for Whisper --- optimum/exporters/executorch/tasks/asr.py | 58 +++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 optimum/exporters/executorch/tasks/asr.py diff --git a/optimum/exporters/executorch/tasks/asr.py b/optimum/exporters/executorch/tasks/asr.py new file mode 100644 index 00000000..bc20bdc1 --- /dev/null +++ b/optimum/exporters/executorch/tasks/asr.py @@ -0,0 +1,58 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoModelForSpeechSeq2Seq + +from ..integrations import Seq2SeqLMExportableModule +from ..task_registry import register_task + + +# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py. +# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier. +@register_task("automatic-speech-recognition") +def load_seq2seq_speech_model(model_name_or_path: str, **kwargs) -> Seq2SeqLMExportableModule: + """ + Loads a model for speech seq2seq and registers it under the task + 'automatic-speech-recognition' using Hugging Face's `AutoModelForSpeechSeq2Seq`. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="openai/whisper-tiny"` or `mode_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - max_hidden_seq_length (int, optional): + Maximum hidden sequence length (default: 4096). + - max_cache_length (int, optional): + Maximum sequence length for generation (default: 1024). + + Returns: + Seq2SeqLMExportableModule: + An instance of `Seq2SeqLMExportableModule` for exporting and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + max_hidden_seq_length = kwargs.get("max_hidden_seq_length", 4096) + max_cache_length = kwargs.get("max_cache_length", 1024) + + full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path).to(device).eval() + return Seq2SeqLMExportableModule( + full_model, + batch_size=batch_size, + max_hidden_seq_length=max_hidden_seq_length, + max_cache_length=max_cache_length, + ) From b2401ff282a1ed05a3301d7d16a2b08830b22cd5 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Fri, 4 Apr 2025 12:16:04 +0200 Subject: [PATCH 02/19] feat: modify integrations.py to make Seq2seqLM export work with Whisper --- optimum/exporters/executorch/integrations.py | 75 ++++++++++++-------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 5b8c37b6..9d401795 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -21,6 +21,8 @@ from transformers.generation.configuration_utils import GenerationConfig from optimum.utils.import_utils import is_transformers_version +from transformers import PreTrainedModel, StaticCache, WhisperForConditionalGeneration +from transformers.generation.configuration_utils import GenerationConfig from .utils import save_config_to_constant_methods @@ -153,7 +155,7 @@ def __init__(self, encoder_model): self.config = encoder_model.config def forward(self, input_ids): - return self.encoder(input_ids=input_ids).last_hidden_state + return self.encoder(input_ids).last_hidden_state class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module): @@ -168,7 +170,10 @@ def __init__(self, model, max_static_cache_length, batch_size): # Get the decoder component self.decoder = model.get_decoder() - self.lm_head = model.lm_head + if isinstance(model, WhisperForConditionalGeneration): + self.proj_out = model.proj_out + else: + self.proj_out = model.lm_head self.config = model.config # Initialize static cache @@ -195,10 +200,9 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): cache_position=cache_position, ) - # Apply language model head - lm_logits = self.lm_head(outputs[0]) - - return lm_logits + # Apply linear projection (lm head) to obtain logits + logits = self.proj_out(outputs[0]) + return logits class Seq2SeqLMExportableModule(torch.nn.Module): @@ -240,14 +244,20 @@ def _export_encoder(self, encoder_input_ids): wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() # Define dynamic sequence length for encoder - seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) + if isinstance(self.full_model, WhisperForConditionalGeneration): + assert encoder_input_ids.shape == torch.Size([1, 80, 3000]), ( + f"Whisper only accepts a log-mel spectrogram of shape [1, 80, 3000], passed shape: {encoder_input_ids.shape}" + ) + dynamic_shapes = None + else: + seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) + dynamic_shapes = {"input_ids": {1: seq_len_dim}} # Export the encoder with torch.no_grad(): exported_encoder = torch.export.export( - wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True + wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, strict=True ) - return exported_encoder def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position): @@ -261,19 +271,23 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi .eval() ) - # Define dynamic dimension for encoder output sequence length - encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + if isinstance(self.full_model, WhisperForConditionalGeneration): + dynamic_shapes = None + else: + # Define dynamic dimension for encoder output sequence length + encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + dynamic_shapes = { + "decoder_input_ids": None, + "encoder_hidden_states": {1: encoder_seq_len_dim}, + "cache_position": None, + } # Export the decoder with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): exported_decoder = torch.export.export( wrapped_decoder, (decoder_input_ids, encoder_hidden_states, cache_position), - dynamic_shapes={ - "decoder_input_ids": None, - "encoder_hidden_states": {1: encoder_seq_len_dim}, - "cache_position": None, - }, + dynamic_shapes=dynamic_shapes, strict=True, ) @@ -286,21 +300,26 @@ def export( encoder_hidden_states=None, cache_position=None, ) -> Dict[str, ExportedProgram]: - example_encoder_input_ids = ( - encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long) - ) + if encoder_input_ids is None: + if isinstance(self.full_model, WhisperForConditionalGeneration): + example_encoder_input_ids = torch.rand((1, 80, 3000)) + else: + example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long) + else: + example_encoder_input_ids = encoder_input_ids + + self.exported_encoder = self._export_encoder(example_encoder_input_ids) + + if not encoder_hidden_states: + example_encoder_hidden_states = self.exported_encoder.module()(example_encoder_input_ids) + else: + example_encoder_hidden_states = encoder_hidden_states + example_decoder_input_ids = ( decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long) - ) # Start token - example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) - example_encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else torch.zeros( - (self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32 - ) ) - self.exported_encoder = self._export_encoder(example_encoder_input_ids) + example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) + self.exported_decoder = self._export_decoder( example_decoder_input_ids, example_encoder_hidden_states, example_cache_position ) From 6e050d7035d063c326e686dbf1e8c57815b05c79 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Fri, 4 Apr 2025 12:16:19 +0200 Subject: [PATCH 03/19] feat: add ExecuTorchModelForSpeechSeq2Seq to modeling.py --- optimum/executorch/modeling.py | 163 +++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index 629cc176..abc7b971 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -29,6 +29,7 @@ AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForSeq2SeqLM, + AutoModelForSpeechSeq2Seq, PretrainedConfig, PreTrainedTokenizer, add_start_docstrings, @@ -871,3 +872,165 @@ def forward( def generate(self): raise NotImplementedError + + +class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase): + """ + A SpeechSeq2Seq ExecuTorch model for inference using the ExecuTorch Runtime. + + This class provides an interface for loading, running, and generating outputs from a seq2seq language model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForSpeechSeq2Seq`. + model (`ExecuTorchModule`): + The loaded ExecuTorch model. + use_kv_cache (`bool`): + Whether key-value caching is enabled. For performance reasons, the exported model is + optimized to use a static cache. + max_cache_size (`int`): + Maximum sequence length supported by the cache. + max_batch_size (`int`): + Maximum supported batch size. + dtype (`str`): + Data type of the model parameters. + bos_token_id (`int`): + Beginning-of-sequence token ID. + eos_token_id (`int`): + End-of-sequence token ID. + vocab_size (`int`): + Size of the model vocabulary. + """ + + auto_model_class = AutoModelForSpeechSeq2Seq + + def __init__(self, encoder: "ExecuTorchModule", decoder: "ExecuTorchModule", config: "PretrainedConfig"): + super().__init__(model=None, config=config) + self.et_encoder = encoder + self.et_decoder = decoder + metadata = self.et_decoder.method_names() + if "use_kv_cache" in metadata: + self.use_kv_cache = self.et_decoder.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.et_decoder.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.et_decoder.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.et_decoder.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.et_decoder.run_method("get_bos_id")[0] + if "get_eos_id" in metadata: + self.eos_token_id = self.et_decoder.run_method("get_eos_id")[0] + if "get_vocab_size" in metadata: + self.vocab_size = self.et_decoder.run_method("get_vocab_size")[0] + if "max_hidden_seq_length" in metadata: + self.max_hidden_seq_length = self.et_decoder.run_method("max_hidden_seq_length")[0] + if "decoder_start_token_id" in metadata: + self.decoder_start_token_id = self.et_decoder.run_method("decoder_start_token_id")[0] + + def forward( + self, + input_features: torch.Tensor, + decoder_input_ids: torch.Tensor, + cache_position: torch.Tensor, + encoder_outputs: Optional[torch.Tensor] = None, + ): + # Encode if needed (first prediction pass) + is_first_prediction = encoder_outputs is None + if is_first_prediction: + encoder_outputs = self.et_encoder.forward((input_features,))[0] + + return (self.et_decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) + + def generate( + self, + input_features: torch.Tensor, + echo: bool = False, + pos_base: int = 0, + max_seq_len: Optional[int] = None, + ) -> List[int]: + """ + Generate tokens from a prompt using the ExecuTorch model. + + Args: + input_features (List[int]): + Log-mel spectrogram for 30-second audio chunk. Can be obtained using the WhisperProcessor. Should be of shape [1, 80, 3000] + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `False`. + pos_base (`int`, *optional*): + Base position for the prompt tokens. Defaults to 0. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + + Returns: + List[int]: List of generated token IDs. + """ + self.device = torch.device("cpu") + if max_seq_len is None: + # Default to max_cache_size if max_seq_len is not specified + max_seq_len = self.max_cache_size + elif max_seq_len > self.max_cache_size: + logging.warning( + f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." + ) + max_seq_len = self.max_cache_size + + if not hasattr(self, "decoder_start_token_id"): + raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.") + decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long) + log_mel = input_features + encoder_outputs = None + generated_ids = [] + + # Generate tokens one by one + for i in range(max_seq_len - 1): + # Run decoder for next token prediction + cache_position = torch.tensor([i], dtype=torch.long) + logits, encoder_outputs = self.forward(log_mel, decoder_input_ids, cache_position, encoder_outputs) + + # Get next token + next_token = torch.argmax(logits[:, -1, :], dim=-1).item() + generated_ids.append(next_token) + + # Update input for next iteration + decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) + + # Check if EOS token + if next_token == self.eos_token_id: + break + + return generated_ids + + def transcribe( + self, + tokenizer: "PreTrainedTokenizer", + input_features: torch.Tensor, + echo: bool = True, + max_seq_len: Optional[int] = None, + ): + """ + Perform text generation task for a given prompt using the ExecuTorch model. + + Args: + tokenizer (`PreTrainedTokenizer`): + The tokenizer used to encode and decode the prompt and output. + input_features (`str`): + Log-mel spectrogram for 30-second audio chunk. Can be obtained using the WhisperProcessor. Should be of shape [1, 80, 3000] + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `True`. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + """ + self.tokenizer = tokenizer + generated_tokens = self.generate( + input_ids=input_features, + echo=echo, + max_seq_len=max_seq_len, + ) + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) From 62800ffcce7445137953ca09d1630ab894cf6f14 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 7 Apr 2025 09:39:39 +0200 Subject: [PATCH 04/19] post rebase constructor fix --- optimum/executorch/modeling.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index abc7b971..1fc03b38 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -873,7 +873,6 @@ def forward( def generate(self): raise NotImplementedError - class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase): """ A SpeechSeq2Seq ExecuTorch model for inference using the ExecuTorch Runtime. @@ -906,29 +905,31 @@ class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase): auto_model_class = AutoModelForSpeechSeq2Seq - def __init__(self, encoder: "ExecuTorchModule", decoder: "ExecuTorchModule", config: "PretrainedConfig"): - super().__init__(model=None, config=config) - self.et_encoder = encoder - self.et_decoder = decoder - metadata = self.et_decoder.method_names() + def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"): + super().__init__(models=models, config=config) + if not hasattr(self, "encoder"): + raise AttributeError("Expected attribute 'encoder' not found in the instance.") + if not hasattr(self, "decoder"): + raise AttributeError("Expected attribute 'decoder' not found in the instance.") + metadata = self.decoder.method_names() if "use_kv_cache" in metadata: - self.use_kv_cache = self.et_decoder.run_method("use_kv_cache")[0] + self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0] if "get_max_seq_len" in metadata: - self.max_cache_size = self.et_decoder.run_method("get_max_seq_len")[0] + self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0] if "get_max_batch_size" in metadata: - self.max_batch_size = self.et_decoder.run_method("get_max_batch_size")[0] + self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0] if "get_dtype" in metadata: - self.dtype = self.et_decoder.run_method("get_dtype")[0] + self.dtype = self.decoder.run_method("get_dtype")[0] if "get_bos_id" in metadata: - self.bos_token_id = self.et_decoder.run_method("get_bos_id")[0] + self.bos_token_id = self.decoder.run_method("get_bos_id")[0] if "get_eos_id" in metadata: - self.eos_token_id = self.et_decoder.run_method("get_eos_id")[0] + self.eos_token_id = self.decoder.run_method("get_eos_id")[0] if "get_vocab_size" in metadata: - self.vocab_size = self.et_decoder.run_method("get_vocab_size")[0] + self.vocab_size = self.decoder.run_method("get_vocab_size")[0] if "max_hidden_seq_length" in metadata: - self.max_hidden_seq_length = self.et_decoder.run_method("max_hidden_seq_length")[0] + self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0] if "decoder_start_token_id" in metadata: - self.decoder_start_token_id = self.et_decoder.run_method("decoder_start_token_id")[0] + self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0] def forward( self, From a108a6f184a2edac56cf78beb8725ab093c4eb81 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 7 Apr 2025 10:28:55 +0200 Subject: [PATCH 05/19] docs: Update README with Whisper --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6819c6ae..e7c107ef 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ We currently support a wide range of popular transformer models, including encod πŸš€ Coming more soon... ### Audio Models -πŸ”Š Coming later +- [Whisper](https://huggingface.co/openai/whisper-tiny): OpenAI's `Whisper` and its variants *πŸ“Œ Note: This list is continuously expanding. As we continue to expand support, more models will be added.* From 5106c693ffd56ca8ff0698784928fd0332054fbf Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 7 Apr 2025 10:29:18 +0200 Subject: [PATCH 06/19] chore: add ExecuTorchModelForSpeechSeq2Seq to __init__.py --- optimum/executorch/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/executorch/__init__.py b/optimum/executorch/__init__.py index 0b3eba1c..07b72a6b 100644 --- a/optimum/executorch/__init__.py +++ b/optimum/executorch/__init__.py @@ -23,6 +23,7 @@ "ExecuTorchModelForImageClassification", "ExecuTorchModelForMaskedLM", "ExecuTorchModelForSeq2SeqLM", + "ExecuTorchModelForSpeechSeq2Seq", ], } @@ -32,6 +33,7 @@ ExecuTorchModelForImageClassification, ExecuTorchModelForMaskedLM, ExecuTorchModelForSeq2SeqLM, + ExecuTorchModelForSpeechSeq2Seq, ) else: import sys From 42a9cd2cc21e50041ffc6b2a7b8e0f7323ba7388 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 7 Apr 2025 10:30:17 +0200 Subject: [PATCH 07/19] fix: post-rebase modelling fix --- optimum/executorch/modeling.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index 1fc03b38..eeba7050 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -24,6 +24,8 @@ import torch from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + +from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch from transformers import ( AutoModelForCausalLM, AutoModelForImageClassification, @@ -36,8 +38,6 @@ ) from transformers.utils import is_offline_mode -from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch - from ..exporters import TasksManager from ..exporters.executorch import main_export from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel @@ -873,6 +873,7 @@ def forward( def generate(self): raise NotImplementedError + class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase): """ A SpeechSeq2Seq ExecuTorch model for inference using the ExecuTorch Runtime. @@ -938,12 +939,11 @@ def forward( cache_position: torch.Tensor, encoder_outputs: Optional[torch.Tensor] = None, ): - # Encode if needed (first prediction pass) is_first_prediction = encoder_outputs is None if is_first_prediction: - encoder_outputs = self.et_encoder.forward((input_features,))[0] + encoder_outputs = self.encoder.forward((input_features,))[0] - return (self.et_decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) + return (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) def generate( self, @@ -1030,7 +1030,7 @@ def transcribe( """ self.tokenizer = tokenizer generated_tokens = self.generate( - input_ids=input_features, + input_features=input_features, echo=echo, max_seq_len=max_seq_len, ) From de1d7f83ee93f4f90110c5e94c778c3d79322f2a Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 7 Apr 2025 10:34:54 +0200 Subject: [PATCH 08/19] tests: add Whisper testing --- tests/models/test_modeling_whisper.py | 72 +++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/models/test_modeling_whisper.py diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py new file mode 100644 index 00000000..1e54ba6b --- /dev/null +++ b/tests/models/test_modeling_whisper.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import subprocess +import tempfile +import unittest + +import pytest +import torch + +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq +from transformers import AutoTokenizer +from transformers.testing_utils import slow + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_whisper_export_to_executorch(self): + model_id = "openai/whisper-tiny" + task = "automatic-speech-recognition" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/encoder.pte")) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/decoder.pte")) + + @slow + @pytest.mark.run_slow + def test_whisper_transcription(self): + model_id = "openai/whisper-tiny" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(model_id, recipe="xnnpack") + + self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq) + self.assertTrue(hasattr(model, "encoder")) + self.assertIsInstance(model.encoder, ExecuTorchModule) + self.assertTrue(hasattr(model, "decoder")) + self.assertIsInstance(model.decoder, ExecuTorchModule) + + # Set manual seed for reproducibility, Whisper could possibly hallucinate tokens + # in some cases if this is not set. + torch.manual_seed(11) + input_features = torch.rand(1, 80, 3000) + generated_transcription = model.transcribe(tokenizer, input_features) + expected_text = "" + logging.info( + f"\nExpected transcription:\n\t{expected_text}\nGenerated transcription:\n\t{generated_transcription}" + ) + self.assertEqual(generated_transcription, expected_text) From ba9644fcf7b6c4e34e51c576b003ecf64b1e117e Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 7 Apr 2025 12:34:52 +0200 Subject: [PATCH 09/19] lint: run black formatter --- optimum/exporters/executorch/integrations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 9d401795..4f1b9cb1 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -245,9 +245,9 @@ def _export_encoder(self, encoder_input_ids): # Define dynamic sequence length for encoder if isinstance(self.full_model, WhisperForConditionalGeneration): - assert encoder_input_ids.shape == torch.Size([1, 80, 3000]), ( - f"Whisper only accepts a log-mel spectrogram of shape [1, 80, 3000], passed shape: {encoder_input_ids.shape}" - ) + assert encoder_input_ids.shape == torch.Size( + [1, 80, 3000] + ), f"Whisper only accepts a log-mel spectrogram of shape [1, 80, 3000], passed shape: {encoder_input_ids.shape}" dynamic_shapes = None else: seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) From 462acc02e12587af5390e8ef97af5a7c8a88dbc4 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 7 Apr 2025 12:51:30 +0200 Subject: [PATCH 10/19] ci: add Whisper test to CI --- .github/workflows/test_models.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index f3e28c63..5a9a88a3 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -40,6 +40,7 @@ jobs: - swin - t5 - vit + - whisper executorch-version: ['0.4.0', 'nightly'] python-version: ['3.10', '3.11', '3.12'] os: [macos-15] From d71ae4a61d8cf7cb891bde6d4c4e25bee0ba432f Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 8 Apr 2025 17:15:03 +0200 Subject: [PATCH 11/19] chore: review changes --- optimum/executorch/modeling.py | 25 ++++++++++-- optimum/exporters/executorch/integrations.py | 42 +++++++++++++++----- tests/models/test_modeling_whisper.py | 22 ++++++---- 3 files changed, 67 insertions(+), 22 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index eeba7050..8e08af3a 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -940,10 +940,14 @@ def forward( encoder_outputs: Optional[torch.Tensor] = None, ): is_first_prediction = encoder_outputs is None + self.stats.on_model_execution_start() if is_first_prediction: encoder_outputs = self.encoder.forward((input_features,))[0] + self.stats.on_prompt_eval_end() - return (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) + result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) + self.stats.on_model_execution_end() + return result def generate( self, @@ -957,7 +961,8 @@ def generate( Args: input_features (List[int]): - Log-mel spectrogram for 30-second audio chunk. Can be obtained using the WhisperProcessor. Should be of shape [1, 80, 3000] + Log-mel spectrogram for 30-second audio chunk. Can be obtained using the WhisperProcessor. Should be of shape [1, 80, 3000] or + [1, 128, 3000]. For details, check out the processor config. echo (`bool`, *optional*): Whether to include prompt tokens in the generated output. Defaults to `False`. pos_base (`int`, *optional*): @@ -986,16 +991,22 @@ def generate( log_mel = input_features encoder_outputs = None generated_ids = [] + first_token_generated = False # Generate tokens one by one for i in range(max_seq_len - 1): # Run decoder for next token prediction cache_position = torch.tensor([i], dtype=torch.long) + self.stats.on_sampling_begin() logits, encoder_outputs = self.forward(log_mel, decoder_input_ids, cache_position, encoder_outputs) - + self.stats.on_sampling_end() + if not first_token_generated: + self.stats.on_first_token() + first_token_generated = True # Get next token next_token = torch.argmax(logits[:, -1, :], dim=-1).item() generated_ids.append(next_token) + self.stats.set_num_generated_tokens(len(generated_ids) - 1) # Don't count decoder_start_token # Update input for next iteration decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) @@ -1020,7 +1031,8 @@ def transcribe( tokenizer (`PreTrainedTokenizer`): The tokenizer used to encode and decode the prompt and output. input_features (`str`): - Log-mel spectrogram for 30-second audio chunk. Can be obtained using the WhisperProcessor. Should be of shape [1, 80, 3000] + Log-mel spectrogram for 30-second audio chunk. Can be obtained using the WhisperProcessor. Should be of shape [1, 80, 3000] or + [1, 128, 3000]. For details, check out the processor config. echo (`bool`, *optional*): Whether to include prompt tokens in the generated output. Defaults to `True`. max_seq_len (`int`, *optional*): @@ -1029,9 +1041,14 @@ def transcribe( Will be truncated to maximal cache size if larger than `max_cache_size`. """ self.tokenizer = tokenizer + + self.stats.reset() + self.stats.on_inference_start() generated_tokens = self.generate( input_features=input_features, echo=echo, max_seq_len=max_seq_len, ) + self.stats.on_inference_end() + self.stats.print_report() return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 4f1b9cb1..1bdc4ee8 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -17,11 +17,15 @@ import torch from torch.export import ExportedProgram from torch.nn.attention import SDPBackend -from transformers import PreTrainedModel, StaticCache -from transformers.generation.configuration_utils import GenerationConfig from optimum.utils.import_utils import is_transformers_version -from transformers import PreTrainedModel, StaticCache, WhisperForConditionalGeneration +from transformers import ( + AutoProcessor, + PreTrainedModel, + StaticCache, + T5ForConditionalGeneration, + WhisperForConditionalGeneration, +) from transformers.generation.configuration_utils import GenerationConfig from .utils import save_config_to_constant_methods @@ -229,6 +233,15 @@ def __init__( "max_cache_len": max_cache_length, }, ) + if isinstance(self.full_model, WhisperForConditionalGeneration): + self._processor = AutoProcessor.from_pretrained(model.config._name_or_path) + self._expected_encoder_input_shape = torch.Size( + ( + 1, + self._processor.feature_extractor.feature_size, + self._processor.feature_extractor.nb_max_frames, + ) + ) additional_configs = {} additional_configs["max_hidden_seq_length"] = max_hidden_seq_length # Metadata to be recorded in the pte model file @@ -245,13 +258,18 @@ def _export_encoder(self, encoder_input_ids): # Define dynamic sequence length for encoder if isinstance(self.full_model, WhisperForConditionalGeneration): - assert encoder_input_ids.shape == torch.Size( - [1, 80, 3000] - ), f"Whisper only accepts a log-mel spectrogram of shape [1, 80, 3000], passed shape: {encoder_input_ids.shape}" + assert ( + encoder_input_ids.shape == self._expected_encoder_input_shape + ), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}. + For more infromation, please refer to the Whisper preprocessor config.""" dynamic_shapes = None + elif isinstance(self.full_model, T5ForConditionalGeneration): + encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}} else: - seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) - dynamic_shapes = {"input_ids": {1: seq_len_dim}} + raise ValueError( + f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export." + ) # Export the encoder with torch.no_grad(): @@ -273,7 +291,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi if isinstance(self.full_model, WhisperForConditionalGeneration): dynamic_shapes = None - else: + elif isinstance(self.full_model, T5ForConditionalGeneration): # Define dynamic dimension for encoder output sequence length encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) dynamic_shapes = { @@ -281,6 +299,10 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi "encoder_hidden_states": {1: encoder_seq_len_dim}, "cache_position": None, } + else: + raise ValueError( + f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export." + ) # Export the decoder with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): @@ -302,7 +324,7 @@ def export( ) -> Dict[str, ExportedProgram]: if encoder_input_ids is None: if isinstance(self.full_model, WhisperForConditionalGeneration): - example_encoder_input_ids = torch.rand((1, 80, 3000)) + example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape) else: example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long) else: diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py index 1e54ba6b..08d4fd21 100644 --- a/tests/models/test_modeling_whisper.py +++ b/tests/models/test_modeling_whisper.py @@ -20,11 +20,11 @@ import unittest import pytest -import torch +from datasets import load_dataset from executorch.extension.pybindings.portable_lib import ExecuTorchModule from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq -from transformers import AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer from transformers.testing_utils import slow @@ -53,6 +53,7 @@ def test_whisper_transcription(self): model_id = "openai/whisper-tiny" tokenizer = AutoTokenizer.from_pretrained(model_id) model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(model_id, recipe="xnnpack") + processor = AutoProcessor.from_pretrained(model_id) self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq) self.assertTrue(hasattr(model, "encoder")) @@ -60,12 +61,17 @@ def test_whisper_transcription(self): self.assertTrue(hasattr(model, "decoder")) self.assertIsInstance(model.decoder, ExecuTorchModule) - # Set manual seed for reproducibility, Whisper could possibly hallucinate tokens - # in some cases if this is not set. - torch.manual_seed(11) - input_features = torch.rand(1, 80, 3000) - generated_transcription = model.transcribe(tokenizer, input_features) - expected_text = "" + dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation") + sample = dataset[0]["audio"] + + input_features = processor( + sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"] + ).input_features + # Current implementation of the transcibe method accepts up to 30 seconds of audio, therefore I trim the audio here. + input_features_trimmed = input_features[:, :, :3000].contiguous() + + generated_transcription = model.transcribe(tokenizer, input_features_trimmed) + expected_text = " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all, and can discover that." logging.info( f"\nExpected transcription:\n\t{expected_text}\nGenerated transcription:\n\t{generated_transcription}" ) From e40b0d10ccf14b8a4145ac9cf5f7fd62cf4765e0 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 8 Apr 2025 20:49:34 +0200 Subject: [PATCH 12/19] deps: add librosa and soundfile to test dependencies --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index eef1eb6b..6d4eed24 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,8 @@ "pytest", "safetensors", "sentencepiece", + "soundfile", + "librosa", ] From f32e7ac7719cb020d362cdf43aa840677d69b8d8 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 9 Apr 2025 09:50:08 +0200 Subject: [PATCH 13/19] deps: force numba versions other than 0.58.0 --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6d4eed24..b04f20a4 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,9 @@ "pytest", "safetensors", "sentencepiece", - "soundfile", + "numba!=0.58.0", "librosa", + "soundfile", ] From d6c17188172e8e98c8f39f7b6c84b32639acb0e8 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 9 Apr 2025 10:48:40 +0200 Subject: [PATCH 14/19] lint: run ruff check --- optimum/exporters/executorch/__main__.py | 4 ++-- optimum/exporters/executorch/utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py index df30a1b1..adc9ef8f 100644 --- a/optimum/exporters/executorch/__main__.py +++ b/optimum/exporters/executorch/__main__.py @@ -20,10 +20,10 @@ from pathlib import Path from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -from transformers import PretrainedConfig -from transformers.utils import is_torch_available from optimum.utils.import_utils import is_transformers_version +from transformers import PretrainedConfig +from transformers.utils import is_torch_available from ...commands.export.executorch import parse_args_executorch from .convert import export_to_executorch diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py index 725ff02f..cf758006 100644 --- a/optimum/exporters/executorch/utils.py +++ b/optimum/exporters/executorch/utils.py @@ -15,6 +15,7 @@ from typing import Optional import torch + from transformers import GenerationConfig, PretrainedConfig From 2a3df8a919eed2140eb71349bd38403716109255 Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Wed, 9 Apr 2025 11:04:31 -0700 Subject: [PATCH 15/19] Update test_models.yml Resolve conflict --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 79df00b3..1a7892b3 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -41,7 +41,7 @@ jobs: - t5 - vit - whisper - executorch-version: ['0.4.0', '0.6.0rc', 'nightly'] + executorch-version: ['0.4.0', 'nightly'] python-version: ['3.10', '3.11', '3.12'] os: [macos-15] From 09fb2774e3f03d0773b2c14a88a49036c11568ac Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Wed, 9 Apr 2025 11:06:20 -0700 Subject: [PATCH 16/19] Update test_models.yml --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 1a7892b3..79df00b3 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -41,7 +41,7 @@ jobs: - t5 - vit - whisper - executorch-version: ['0.4.0', 'nightly'] + executorch-version: ['0.4.0', '0.6.0rc', 'nightly'] python-version: ['3.10', '3.11', '3.12'] os: [macos-15] From edbb88c39b706277cb05ff7b4fa752b79bed9ad6 Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Wed, 9 Apr 2025 11:10:50 -0700 Subject: [PATCH 17/19] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9fbb6675..ce24db7a 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "pytest", "safetensors", "sentencepiece", - "numba!=0.58.0", + "numba!=0.58.0", # Due to the bug https://github.com/numba/numba/issues/9209 "librosa", "soundfile", ] From e4b053b1ea7012d2536c3637f826c11aae522634 Mon Sep 17 00:00:00 2001 From: Jakub Chmura <92989966+chmjkb@users.noreply.github.com> Date: Thu, 10 Apr 2025 10:52:53 +0200 Subject: [PATCH 18/19] Update README.md Co-authored-by: Guang Yang <42389959+guangy10@users.noreply.github.com> --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e7c107ef..76ae69d6 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,7 @@ We currently support a wide range of popular transformer models, including encod πŸš€ Coming more soon... ### Audio Models +#### Encoder-decoder models - [Whisper](https://huggingface.co/openai/whisper-tiny): OpenAI's `Whisper` and its variants *πŸ“Œ Note: This list is continuously expanding. As we continue to expand support, more models will be added.* From f396043298a253ad6a70844f2cd04137731b6f00 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Thu, 10 Apr 2025 11:09:37 +0200 Subject: [PATCH 19/19] lint: run make lint --- optimum/executorch/modeling.py | 4 ++-- optimum/exporters/executorch/__main__.py | 4 ++-- optimum/exporters/executorch/integrations.py | 4 ++-- optimum/exporters/executorch/utils.py | 1 - tests/models/test_modeling_whisper.py | 4 ++-- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index 8e08af3a..e156ef65 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -24,8 +24,6 @@ import torch from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE - -from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch from transformers import ( AutoModelForCausalLM, AutoModelForImageClassification, @@ -38,6 +36,8 @@ ) from transformers.utils import is_offline_mode +from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch + from ..exporters import TasksManager from ..exporters.executorch import main_export from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py index adc9ef8f..df30a1b1 100644 --- a/optimum/exporters/executorch/__main__.py +++ b/optimum/exporters/executorch/__main__.py @@ -20,11 +20,11 @@ from pathlib import Path from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE - -from optimum.utils.import_utils import is_transformers_version from transformers import PretrainedConfig from transformers.utils import is_torch_available +from optimum.utils.import_utils import is_transformers_version + from ...commands.export.executorch import parse_args_executorch from .convert import export_to_executorch from .task_registry import discover_tasks, task_registry diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 1bdc4ee8..64bbe7dc 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -17,8 +17,6 @@ import torch from torch.export import ExportedProgram from torch.nn.attention import SDPBackend - -from optimum.utils.import_utils import is_transformers_version from transformers import ( AutoProcessor, PreTrainedModel, @@ -28,6 +26,8 @@ ) from transformers.generation.configuration_utils import GenerationConfig +from optimum.utils.import_utils import is_transformers_version + from .utils import save_config_to_constant_methods diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py index cf758006..725ff02f 100644 --- a/optimum/exporters/executorch/utils.py +++ b/optimum/exporters/executorch/utils.py @@ -15,7 +15,6 @@ from typing import Optional import torch - from transformers import GenerationConfig, PretrainedConfig diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py index 08d4fd21..f6383823 100644 --- a/tests/models/test_modeling_whisper.py +++ b/tests/models/test_modeling_whisper.py @@ -21,12 +21,12 @@ import pytest from datasets import load_dataset - from executorch.extension.pybindings.portable_lib import ExecuTorchModule -from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq from transformers import AutoProcessor, AutoTokenizer from transformers.testing_utils import slow +from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq + class ExecuTorchModelIntegrationTest(unittest.TestCase): def __init__(self, *args, **kwargs):