1 change: 1 addition & 0 deletions .github/workflows/test_models.yml
@@ -40,6 +40,7 @@ jobs:
- swin
- t5
- vit
- whisper
executorch-version: ['0.4.0', 'nightly']
python-version: ['3.10', '3.11', '3.12']
os: [macos-15]
2 changes: 1 addition & 1 deletion README.md
@@ -162,7 +162,7 @@ We currently support a wide range of popular transformer models, including encod
🚀 Coming more soon...

### Audio Models
🔊 Coming later
- [Whisper](https://huggingface.co/openai/whisper-tiny): OpenAI's `Whisper` and its variants

*📌 Note: This list is continuously expanding. As we continue to expand support, more models will be added.*

2 changes: 2 additions & 0 deletions optimum/executorch/__init__.py
@@ -23,6 +23,7 @@
"ExecuTorchModelForImageClassification",
"ExecuTorchModelForMaskedLM",
"ExecuTorchModelForSeq2SeqLM",
"ExecuTorchModelForSpeechSeq2Seq",
],
}

@@ -32,6 +33,7 @@
ExecuTorchModelForImageClassification,
ExecuTorchModelForMaskedLM,
ExecuTorchModelForSeq2SeqLM,
ExecuTorchModelForSpeechSeq2Seq,
)
else:
import sys
168 changes: 166 additions & 2 deletions optimum/executorch/modeling.py
@@ -24,19 +24,20 @@
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
from transformers import (
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
PretrainedConfig,
PreTrainedTokenizer,
add_start_docstrings,
)
from transformers.utils import is_offline_mode

from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch

from ..exporters import TasksManager
from ..exporters.executorch import main_export
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
@@ -871,3 +872,166 @@ def forward(

def generate(self):
raise NotImplementedError


class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase):
"""
A SpeechSeq2Seq ExecuTorch model for inference using the ExecuTorch Runtime.

This class provides an interface for loading, running, and generating transcriptions from a speech sequence-to-sequence model
optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models
compatible with the ExecuTorch runtime.

Attributes:
auto_model_class (`Type`):
Associated Transformers class, `AutoModelForSpeechSeq2Seq`.
model (`ExecuTorchModule`):
The loaded ExecuTorch model.
use_kv_cache (`bool`):
Whether key-value caching is enabled. For performance reasons, the exported model is
optimized to use a static cache.
max_cache_size (`int`):
Maximum sequence length supported by the cache.
max_batch_size (`int`):
Maximum supported batch size.
dtype (`str`):
Data type of the model parameters.
bos_token_id (`int`):
Beginning-of-sequence token ID.
eos_token_id (`int`):
End-of-sequence token ID.
vocab_size (`int`):
Size of the model vocabulary.
"""

auto_model_class = AutoModelForSpeechSeq2Seq

def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
super().__init__(models=models, config=config)
if not hasattr(self, "encoder"):
raise AttributeError("Expected attribute 'encoder' not found in the instance.")
if not hasattr(self, "decoder"):
raise AttributeError("Expected attribute 'decoder' not found in the instance.")
metadata = self.decoder.method_names()
if "use_kv_cache" in metadata:
self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
if "get_max_seq_len" in metadata:
self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
if "get_max_batch_size" in metadata:
self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
if "get_dtype" in metadata:
self.dtype = self.decoder.run_method("get_dtype")[0]
if "get_bos_id" in metadata:
self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
if "get_eos_id" in metadata:
self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
if "get_vocab_size" in metadata:
self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
if "max_hidden_seq_length" in metadata:
self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
if "decoder_start_token_id" in metadata:
self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]

def forward(
self,
input_features: torch.Tensor,
decoder_input_ids: torch.Tensor,
cache_position: torch.Tensor,
encoder_outputs: Optional[torch.Tensor] = None,
):
is_first_prediction = encoder_outputs is None
if is_first_prediction:
encoder_outputs = self.encoder.forward((input_features,))[0]

return (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)

def generate(
self,
input_features: torch.Tensor,
echo: bool = False,
pos_base: int = 0,
max_seq_len: Optional[int] = None,
) -> List[int]:
"""
Generate tokens from audio input features using the ExecuTorch model.

Args:
input_features (`torch.Tensor`):
Log-mel spectrogram for a 30-second audio chunk, of shape `[1, 80, 3000]`. Can be obtained using the `WhisperProcessor`.
echo (`bool`, *optional*):
Whether to include prompt tokens in the generated output. Defaults to `False`.
pos_base (`int`, *optional*):
Base position for the prompt tokens. Defaults to 0.
max_seq_len (`int`, *optional*):
Maximum sequence length for the generated output.
Defaults to None and uses the model's `max_cache_size` attribute.
Will be truncated to `max_cache_size` if larger.

Returns:
List[int]: List of generated token IDs.
"""
self.device = torch.device("cpu")
if max_seq_len is None:
# Default to max_cache_size if max_seq_len is not specified
max_seq_len = self.max_cache_size
elif max_seq_len > self.max_cache_size:
logging.warning(
f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size."
)
max_seq_len = self.max_cache_size

if not hasattr(self, "decoder_start_token_id"):
raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.")
decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long)
log_mel = input_features
encoder_outputs = None
generated_ids = []

# Generate tokens one by one
for i in range(max_seq_len - 1):
# Run decoder for next token prediction
cache_position = torch.tensor([i], dtype=torch.long)
logits, encoder_outputs = self.forward(log_mel, decoder_input_ids, cache_position, encoder_outputs)

# Get next token
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
generated_ids.append(next_token)

# Update input for next iteration
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)

# Check if EOS token
if next_token == self.eos_token_id:
break

return generated_ids

def transcribe(
self,
tokenizer: "PreTrainedTokenizer",
input_features: torch.Tensor,
echo: bool = True,
max_seq_len: Optional[int] = None,
):
"""
Transcribe audio input features into text using the ExecuTorch model.

Args:
tokenizer (`PreTrainedTokenizer`):
The tokenizer used to encode and decode the prompt and output.
input_features (`torch.Tensor`):
Log-mel spectrogram for a 30-second audio chunk, of shape `[1, 80, 3000]`. Can be obtained using the `WhisperProcessor`.
echo (`bool`, *optional*):
Whether to include prompt tokens in the generated output. Defaults to `True`.
max_seq_len (`int`, *optional*):
Maximum sequence length for the generated output.
Defaults to None and uses the model's `max_cache_size` attribute.
Will be truncated to `max_cache_size` if larger.
"""
self.tokenizer = tokenizer
generated_tokens = self.generate(
input_features=input_features,
echo=echo,
max_seq_len=max_seq_len,
)
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
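For orientation, here is a minimal usage sketch of the new class (not part of this diff). It assumes that the `from_pretrained` export path and the `"xnnpack"` recipe used by the other `ExecuTorchModelFor*` classes also apply to `ExecuTorchModelForSpeechSeq2Seq`, and it pairs the model with `WhisperProcessor` to produce the `[1, 80, 3000]` log-mel features that `transcribe` expects:

# Hedged sketch: the model id, the recipe name, and the from_pretrained export
# path are assumptions carried over from the other ExecuTorch model classes.
import numpy as np
from transformers import WhisperProcessor
from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(model_id)
model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(model_id, recipe="xnnpack")

# 30 seconds of silence stands in for real 16 kHz mono audio.
audio = np.zeros(16000 * 30, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
input_features = inputs.input_features  # shape [1, 80, 3000]

text = model.transcribe(tokenizer=processor.tokenizer, input_features=input_features)
print(text)

Transcribing silence is only a smoke test; real audio would typically be loaded with a library such as librosa or the datasets audio feature.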
75 changes: 47 additions & 28 deletions optimum/exporters/executorch/integrations.py
@@ -21,6 +21,8 @@
from transformers.generation.configuration_utils import GenerationConfig

from optimum.utils.import_utils import is_transformers_version
from transformers import PreTrainedModel, StaticCache, WhisperForConditionalGeneration
from transformers.generation.configuration_utils import GenerationConfig

from .utils import save_config_to_constant_methods

@@ -153,7 +155,7 @@ def __init__(self, encoder_model):
self.config = encoder_model.config

def forward(self, input_ids):
return self.encoder(input_ids=input_ids).last_hidden_state
return self.encoder(input_ids).last_hidden_state


class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
@@ -168,7 +170,10 @@ def __init__(self, model, max_static_cache_length, batch_size):

# Get the decoder component
self.decoder = model.get_decoder()
self.lm_head = model.lm_head
if isinstance(model, WhisperForConditionalGeneration):
self.proj_out = model.proj_out
else:
self.proj_out = model.lm_head
self.config = model.config

# Initialize static cache
@@ -195,10 +200,9 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
cache_position=cache_position,
)

# Apply language model head
lm_logits = self.lm_head(outputs[0])

return lm_logits
# Apply linear projection (lm head) to obtain logits
logits = self.proj_out(outputs[0])
return logits


class Seq2SeqLMExportableModule(torch.nn.Module):
@@ -240,14 +244,20 @@ def _export_encoder(self, encoder_input_ids):
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()

# Define dynamic sequence length for encoder
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
if isinstance(self.full_model, WhisperForConditionalGeneration):
assert encoder_input_ids.shape == torch.Size(
[1, 80, 3000]
), f"Whisper only accepts a log-mel spectrogram of shape [1, 80, 3000], passed shape: {encoder_input_ids.shape}"
Review comment (Collaborator):
@chmjkb This is the config for whisper-tiny only, right? If I switch to a different variant of Whisper it won't work, e.g. whisper-large-v3. I think we can load this dynamically from preprocessor_config.json? IIUC, each dim in encoder_input_ids represents [batch_size, feature_size, nb_max_frames]?

Reply (@chmjkb, Contributor Author, Apr 7, 2025):
Yeah, I think whisper-large is an exception and it will take in 128 features instead of 80. Will fix that. (The smaller ones should work with 80.)

Reply (@chmjkb, Contributor Author, Apr 8, 2025):
Actually, I'm thinking about getting rid of this assertion. Instead, I could just resolve the correct shape when instantiating example_encoder_input_ids (I'm doing this anyway). If a user passes the wrong shape for some reason, then be it. Also, WhisperEncoder itself raises a ValueError when the length of the features is not correct. WDYT?

dynamic_shapes = None
else:
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
dynamic_shapes = {"input_ids": {1: seq_len_dim}}

# Export the encoder
with torch.no_grad():
exported_encoder = torch.export.export(
wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, strict=True
)

return exported_encoder
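
Picking up the review thread above about whisper-large-v3: a hedged sketch (not part of this PR) of resolving the expected spectrogram shape from preprocessor_config.json instead of hard-coding [1, 80, 3000]. The attribute names come from transformers' WhisperFeatureExtractor; wiring this into _export_encoder is left as an assumption.

# Hedged sketch: derive [batch_size, feature_size, nb_max_frames] per checkpoint.
from transformers import AutoFeatureExtractor

def expected_input_features_shape(model_id: str) -> tuple:
    # preprocessor_config.json carries feature_size (80 for whisper-tiny,
    # 128 for whisper-large-v3) and the 30-second frame count (3000).
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
    return (1, feature_extractor.feature_size, feature_extractor.nb_max_frames)

# expected_input_features_shape("openai/whisper-tiny")     -> (1, 80, 3000)
# expected_input_features_shape("openai/whisper-large-v3") -> (1, 128, 3000)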

def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
@@ -261,19 +271,23 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
.eval()
)

# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
if isinstance(self.full_model, WhisperForConditionalGeneration):
dynamic_shapes = None
else:
# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
dynamic_shapes = {
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
}
Review comment (Collaborator):
@tugsbayasgalan @pianpwk Can we use Dim.AUTO here, to avoid setting dynamic_shapes explicitly for different models? In this case, Whisper would expect all static shapes and T5 would want to set encoder_hidden_seq_length to be dynamic at least.

Reply (Collaborator):
Yep, Dim.AUTO would be perfect here. Doing so, you don't need the if/else branching.

Reply (@chmjkb, Contributor Author, Apr 8, 2025):
I tried changing the code to the following:

dynamic_shapes = {
    "decoder_input_ids": None,
    "encoder_hidden_states": {1: torch.export.Dim.AUTO},
    "cache_position": None
}

Unfortunately, when I do that, the export fails with the following error:

RuntimeError: Cannot evaluate the shape upper bound of a dynamic-shaped tensor to a concrete bounded integer. Got tensor spec: TensorSpec(dtype=torch.float32, shape=[1, s0, 384], layout=torch.strided, is_sparse=False, shape_dynamism=1, const=False, requires_grad=True). The upper bound shape we get [1, int_oo, 384], the upper bound stride we get [int_oo, 384, 1]. This tensor could either be from 1. a data-dependent operation such as nonzero. Or 2. an input, whose don't have a constraint for the upper bound. Please use export's constrain_as_size() or constrain_as_value() apis and set a concrete upper bound to resolve this.

However, doing:

dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}}

for the Whisper encoder seems to work; however, it makes the T5 export fail :D

# Export the decoder
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
exported_decoder = torch.export.export(
wrapped_decoder,
(decoder_input_ids, encoder_hidden_states, cache_position),
dynamic_shapes={
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
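
To recap the thread above: the explicit bounded Dim is kept because lowering to ExecuTorch needs a concrete upper bound on the dynamic dimension, which Dim.AUTO does not provide here. A minimal sketch of the two variants (the max value and the 384 hidden size in the error are illustrative; the PR uses self.max_hidden_seq_length):

import torch

# Bounded dynamic dim (what the PR keeps for non-Whisper models): the exported
# program carries a concrete upper bound that ExecuTorch can use to size buffers.
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=4096)
bounded_dynamic_shapes = {
    "decoder_input_ids": None,
    "encoder_hidden_states": {1: encoder_seq_len_dim},
    "cache_position": None,
}

# Dim.AUTO variant from the discussion: dynamism is inferred, but without a
# finite upper bound, which is what triggered the "Cannot evaluate the shape
# upper bound" error when lowering to ExecuTorch.
auto_dynamic_shapes = {
    "decoder_input_ids": None,
    "encoder_hidden_states": {1: torch.export.Dim.AUTO},
    "cache_position": None,
}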

@@ -286,21 +300,26 @@ def export(
encoder_hidden_states=None,
cache_position=None,
) -> Dict[str, ExportedProgram]:
example_encoder_input_ids = (
encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
)
if encoder_input_ids is None:
if isinstance(self.full_model, WhisperForConditionalGeneration):
example_encoder_input_ids = torch.rand((1, 80, 3000))
else:
example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
else:
example_encoder_input_ids = encoder_input_ids

self.exported_encoder = self._export_encoder(example_encoder_input_ids)

if encoder_hidden_states is None:
example_encoder_hidden_states = self.exported_encoder.module()(example_encoder_input_ids)
else:
example_encoder_hidden_states = encoder_hidden_states

example_decoder_input_ids = (
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
) # Start token
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
example_encoder_hidden_states = (
encoder_hidden_states
if encoder_hidden_states is not None
else torch.zeros(
(self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
)
)
self.exported_encoder = self._export_encoder(example_encoder_input_ids)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

self.exported_decoder = self._export_decoder(
example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
)