1 change: 1 addition & 0 deletions .github/workflows/test_models.yml
@@ -40,6 +40,7 @@ jobs:
- swin
- t5
- vit
- whisper
executorch-version: ['0.4.0', '0.6.0rc', 'nightly']
python-version: ['3.10', '3.11', '3.12']
os: [macos-15]
3 changes: 2 additions & 1 deletion README.md
@@ -162,7 +162,8 @@ We currently support a wide range of popular transformer models, including encod
🚀 Coming more soon...

### Audio Models
🔊 Coming later
#### Encoder-decoder models
- [Whisper](https://huggingface.co/openai/whisper-tiny): OpenAI's `Whisper` and its variants

*📌 Note: This list is continuously expanding. As we continue to expand support, more models will be added.*
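A minimal loading sketch for the newly listed Whisper support (the `recipe` argument and export-on-load behavior are assumed to mirror the existing examples for the other ExecuTorch model classes):

```python
from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq

# Assumption: from_pretrained exports (or loads) the PTE with the same
# recipe-based API used by the other ExecuTorch model classes.
model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", recipe="xnnpack")
```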

2 changes: 2 additions & 0 deletions optimum/executorch/__init__.py
@@ -23,6 +23,7 @@
"ExecuTorchModelForImageClassification",
"ExecuTorchModelForMaskedLM",
"ExecuTorchModelForSeq2SeqLM",
"ExecuTorchModelForSpeechSeq2Seq",
],
}

@@ -32,6 +33,7 @@
ExecuTorchModelForImageClassification,
ExecuTorchModelForMaskedLM,
ExecuTorchModelForSeq2SeqLM,
ExecuTorchModelForSpeechSeq2Seq,
)
else:
import sys
181 changes: 181 additions & 0 deletions optimum/executorch/modeling.py
@@ -29,6 +29,7 @@
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
PretrainedConfig,
PreTrainedTokenizer,
add_start_docstrings,
@@ -871,3 +872,183 @@ def forward(

def generate(self):
raise NotImplementedError


class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase):
"""
A SpeechSeq2Seq ExecuTorch model for inference using the ExecuTorch Runtime.

This class provides an interface for loading, running, and generating transcriptions from a speech-to-text
sequence-to-sequence model optimized for the ExecuTorch Runtime. It includes utilities for exporting and loading
pre-trained models compatible with the ExecuTorch runtime.

Attributes:
auto_model_class (`Type`):
Associated Transformers class, `AutoModelForSpeechSeq2Seq`.
model (`ExecuTorchModule`):
The loaded ExecuTorch model.
use_kv_cache (`bool`):
Whether key-value caching is enabled. For performance reasons, the exported model is
optimized to use a static cache.
max_cache_size (`int`):
Maximum sequence length supported by the cache.
max_batch_size (`int`):
Maximum supported batch size.
dtype (`str`):
Data type of the model parameters.
bos_token_id (`int`):
Beginning-of-sequence token ID.
eos_token_id (`int`):
End-of-sequence token ID.
vocab_size (`int`):
Size of the model vocabulary.
"""

auto_model_class = AutoModelForSpeechSeq2Seq

def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
super().__init__(models=models, config=config)
if not hasattr(self, "encoder"):
raise AttributeError("Expected attribute 'encoder' not found in the instance.")
if not hasattr(self, "decoder"):
raise AttributeError("Expected attribute 'decoder' not found in the instance.")
metadata = self.decoder.method_names()
if "use_kv_cache" in metadata:
self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
if "get_max_seq_len" in metadata:
self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
if "get_max_batch_size" in metadata:
self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
if "get_dtype" in metadata:
self.dtype = self.decoder.run_method("get_dtype")[0]
if "get_bos_id" in metadata:
self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
if "get_eos_id" in metadata:
self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
if "get_vocab_size" in metadata:
self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
if "max_hidden_seq_length" in metadata:
self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
if "decoder_start_token_id" in metadata:
self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]

def forward(
self,
input_features: torch.Tensor,
decoder_input_ids: torch.Tensor,
cache_position: torch.Tensor,
encoder_outputs: Optional[torch.Tensor] = None,
):
is_first_prediction = encoder_outputs is None
self.stats.on_model_execution_start()
if is_first_prediction:
encoder_outputs = self.encoder.forward((input_features,))[0]
self.stats.on_prompt_eval_end()

result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
self.stats.on_model_execution_end()
return result

def generate(
self,
input_features: torch.Tensor,
echo: bool = False,
pos_base: int = 0,
max_seq_len: Optional[int] = None,
) -> List[int]:
"""
Generate tokens from audio input features using the ExecuTorch model.

Args:
input_features (`torch.Tensor`):
Log-mel spectrogram for a 30-second audio chunk, as produced by the `WhisperProcessor`. Should be of shape [1, 80, 3000] or
[1, 128, 3000]. For details, check out the processor config.
echo (`bool`, *optional*):
Whether to include prompt tokens in the generated output. Defaults to `False`.
pos_base (`int`, *optional*):
Base position for the prompt tokens. Defaults to 0.
max_seq_len (`int`, *optional*):
Maximum sequence length for the generated output.
Defaults to None and uses the model's `max_cache_size` attribute.
Will be truncated to `max_cache_size` if larger.

Returns:
List[int]: List of generated token IDs.
"""
self.device = torch.device("cpu")
if max_seq_len is None:
# Default to max_cache_size if max_seq_len is not specified
max_seq_len = self.max_cache_size
elif max_seq_len > self.max_cache_size:
logging.warning(
f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size."
)
max_seq_len = self.max_cache_size

if not hasattr(self, "decoder_start_token_id"):
raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.")
decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long)
log_mel = input_features
encoder_outputs = None
generated_ids = []
first_token_generated = False

# Generate tokens one by one
for i in range(max_seq_len - 1):
# Run decoder for next token prediction
cache_position = torch.tensor([i], dtype=torch.long)
self.stats.on_sampling_begin()
logits, encoder_outputs = self.forward(log_mel, decoder_input_ids, cache_position, encoder_outputs)
self.stats.on_sampling_end()
if not first_token_generated:
self.stats.on_first_token()
first_token_generated = True
# Get next token
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
generated_ids.append(next_token)
self.stats.set_num_generated_tokens(len(generated_ids) - 1) # Don't count decoder_start_token

# Update input for next iteration
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)

# Check if EOS token
if next_token == self.eos_token_id:
break

return generated_ids

def transcribe(
self,
tokenizer: "PreTrainedTokenizer",
input_features: torch.Tensor,
echo: bool = True,
max_seq_len: Optional[int] = None,
):
"""
Transcribe audio into text using the ExecuTorch model.

Args:
tokenizer (`PreTrainedTokenizer`):
The tokenizer used to encode and decode the prompt and output.
input_features (`torch.Tensor`):
Log-mel spectrogram for a 30-second audio chunk, as produced by the `WhisperProcessor`. Should be of shape [1, 80, 3000] or
[1, 128, 3000]. For details, check out the processor config.
echo (`bool`, *optional*):
Whether to include prompt tokens in the generated output. Defaults to `True`.
max_seq_len (`int`, *optional*):
Maximum sequence length for the generated output.
Defaults to None and uses the model's `max_cache_size` attribute.
Will be truncated to `max_cache_size` if larger.
"""
self.tokenizer = tokenizer

self.stats.reset()
self.stats.on_inference_start()
generated_tokens = self.generate(
input_features=input_features,
echo=echo,
max_seq_len=max_seq_len,
)
self.stats.on_inference_end()
self.stats.print_report()
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
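A hedged end-to-end sketch of how the new class is intended to be used. The `from_pretrained` arguments are assumed to follow the other ExecuTorch model classes; the processor call and the `transcribe` signature come from the code above.

```python
import torch
from transformers import WhisperProcessor

from optimum.executorch import ExecuTorchModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(model_id)
# Assumption: loading mirrors the other ExecuTorch model classes.
model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(model_id, recipe="xnnpack")

# 30 seconds of 16 kHz audio (silence here); the processor pads/truncates to the
# fixed [1, 80, 3000] log-mel shape that the exported encoder expects.
audio = torch.zeros(16000 * 30).numpy()
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

text = model.transcribe(tokenizer=processor.tokenizer, input_features=input_features)
print(text)
```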
99 changes: 70 additions & 29 deletions optimum/exporters/executorch/integrations.py
@@ -17,7 +17,13 @@
import torch
from torch.export import ExportedProgram
from torch.nn.attention import SDPBackend
from transformers import PreTrainedModel, StaticCache
from transformers import (
AutoProcessor,
PreTrainedModel,
StaticCache,
T5ForConditionalGeneration,
WhisperForConditionalGeneration,
)
from transformers.generation.configuration_utils import GenerationConfig

from optimum.utils.import_utils import is_transformers_version
@@ -153,7 +159,7 @@ def __init__(self, encoder_model):
self.config = encoder_model.config

def forward(self, input_ids):
return self.encoder(input_ids=input_ids).last_hidden_state
return self.encoder(input_ids).last_hidden_state


class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
@@ -168,7 +174,10 @@ def __init__(self, model, max_static_cache_length, batch_size):

# Get the decoder component
self.decoder = model.get_decoder()
self.lm_head = model.lm_head
if isinstance(model, WhisperForConditionalGeneration):
self.proj_out = model.proj_out
else:
self.proj_out = model.lm_head
self.config = model.config

# Initialize static cache
@@ -195,10 +204,9 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
cache_position=cache_position,
)

# Apply language model head
lm_logits = self.lm_head(outputs[0])

return lm_logits
# Apply linear projection (lm head) to obtain logits
logits = self.proj_out(outputs[0])
return logits


class Seq2SeqLMExportableModule(torch.nn.Module):
@@ -225,6 +233,15 @@ def __init__(
"max_cache_len": max_cache_length,
},
)
if isinstance(self.full_model, WhisperForConditionalGeneration):
self._processor = AutoProcessor.from_pretrained(model.config._name_or_path)
self._expected_encoder_input_shape = torch.Size(
(
1,
self._processor.feature_extractor.feature_size,
self._processor.feature_extractor.nb_max_frames,
)
)
additional_configs = {}
additional_configs["max_hidden_seq_length"] = max_hidden_seq_length
# Metadata to be recorded in the pte model file
@@ -240,14 +257,25 @@ def _export_encoder(self, encoder_input_ids):
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()

# Define dynamic sequence length for encoder
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
if isinstance(self.full_model, WhisperForConditionalGeneration):
assert (
encoder_input_ids.shape == self._expected_encoder_input_shape
), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}.
For more information, please refer to the Whisper preprocessor config."""
dynamic_shapes = None
elif isinstance(self.full_model, T5ForConditionalGeneration):
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}}
else:
raise ValueError(
f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export."
)

# Export the encoder
with torch.no_grad():
exported_encoder = torch.export.export(
wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, strict=True
)

return exported_encoder

def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
@@ -261,19 +289,27 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
.eval()
)

# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
if isinstance(self.full_model, WhisperForConditionalGeneration):
dynamic_shapes = None
elif isinstance(self.full_model, T5ForConditionalGeneration):
# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
dynamic_shapes = {
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
}
else:
raise ValueError(
f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
)

# Export the decoder
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
exported_decoder = torch.export.export(
wrapped_decoder,
(decoder_input_ids, encoder_hidden_states, cache_position),
dynamic_shapes={
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)

@@ -286,21 +322,26 @@ def export(
encoder_hidden_states=None,
cache_position=None,
) -> Dict[str, ExportedProgram]:
example_encoder_input_ids = (
encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
)
if encoder_input_ids is None:
if isinstance(self.full_model, WhisperForConditionalGeneration):
example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
else:
example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
else:
example_encoder_input_ids = encoder_input_ids

self.exported_encoder = self._export_encoder(example_encoder_input_ids)

if not encoder_hidden_states:
example_encoder_hidden_states = self.exported_encoder.module()(example_encoder_input_ids)
else:
example_encoder_hidden_states = encoder_hidden_states

example_decoder_input_ids = (
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
) # Start token
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
example_encoder_hidden_states = (
encoder_hidden_states
if encoder_hidden_states is not None
else torch.zeros(
(self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
)
)
self.exported_encoder = self._export_encoder(example_encoder_input_ids)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

self.exported_decoder = self._export_decoder(
example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
)
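For reference, a small sketch of the shape handling above: Whisper's encoder takes a fixed-size log-mel tensor whose dimensions come from the feature extractor, so it is exported with static shapes (`dynamic_shapes=None`), whereas T5 keeps a dynamic encoder sequence dimension. The numbers below are read from the processor and are illustrative only.

```python
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
fe = processor.feature_extractor

# Whisper: fixed encoder input, e.g. (1, 80, 3000) for whisper-tiny
# (feature_size mel bins x nb_max_frames frames for a 30-second window).
whisper_example_input = torch.rand(1, fe.feature_size, fe.nb_max_frames)

# T5: token ids with a dynamic sequence dimension, bounded by max_hidden_seq_length
# (4096 here is an illustrative bound, not the library default).
seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=4096)
t5_dynamic_shapes = {"input_ids": {1: seq_len_dim}}
t5_example_input = torch.ones((1, 10), dtype=torch.long)
```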