32 changes: 25 additions & 7 deletions optimum/executorch/modeling.py
@@ -624,7 +624,14 @@ def forward(
torch.Tensor: Logits output from the model.
"""
self.stats.on_model_execution_start()
logits = self.model.forward((input_ids, cache_position))[0]

try:
logits = self.model.forward((input_ids, cache_position))[0]
except Exception as e:
shapes = {name: val.shape for name, val in locals().items() if hasattr(val, "shape")}
print(f"Exception: {e}.\n{self.model.method_meta('forward')}\narg shapes: {shapes}")
Comment on lines +631 to +632
Collaborator: what is this supposed to do

Collaborator (Author): Print the method_meta and the shapes of all input args on error, making it easier to debug the issue. For example, in https://github.com/huggingface/optimum-executorch/actions/runs/15889548273/job/44809292056#step:5:424 you can easily tell what went wrong.
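
A minimal, self-contained sketch of that debug pattern, assuming `et_module` is a loaded `ExecuTorchModule`; the wrapper name `debug_forward` is illustrative, not part of the PR:

```python
def debug_forward(et_module, input_ids, cache_position):
    """Run the exported forward method and dump diagnostics on failure."""
    try:
        return et_module.forward((input_ids, cache_position))[0]
    except Exception as e:
        # Shapes of every tensor-valued local, so the failing call can be
        # compared against the exported method's expected input metadata.
        shapes = {name: val.shape for name, val in locals().items() if hasattr(val, "shape")}
        print(f"Exception: {e}.\n{et_module.method_meta('forward')}\narg shapes: {shapes}")
        raise
```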

raise

self.stats.on_model_execution_end()
return logits

@@ -667,20 +674,31 @@ def generate(
)
max_seq_len = self.max_cache_size
generated_tokens = []
seq_len = self.model.method_meta("forward").input_tensor_meta(1).sizes()[0]

# prefill
for i, prompt_token in enumerate(prompt_tokens):
if seq_len > 1:
# The model is exported with dynamic shapes. Can support parallel prefill.
self.stats.on_sampling_begin()
logits = self.forward(
input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
)
self.stats.on_sampling_end()

next_token = torch.argmax(logits, dim=-1)[0, -1].item()
else:
# Sequential prefill is preserved for backwards compatibility in order to run PTE generated w/o dynamic shapes.
# TODO: We can remove this block once the executorch runtime supports `cache_position`.
for i, prompt_token in enumerate(prompt_tokens):
self.stats.on_sampling_begin()
logits = self.forward(
input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
)
self.stats.on_sampling_end()
next_token = torch.argmax(logits, dim=-1).item()
self.stats.on_prompt_eval_end()
first_token_generated = False

next_token = torch.argmax(logits, dim=-1).item()
generated_tokens = prompt_tokens + [next_token]

while len(generated_tokens) < max_seq_len:
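The prefill change above dispatches on the exported size of the `cache_position` input: a size greater than 1 means the PTE was exported with dynamic shapes and the whole prompt can be prefilled in one call, otherwise it falls back to token-by-token prefill. A hedged sketch of that dispatch, with `et_module` standing in for the underlying `ExecuTorchModule` and the helper name `prefill` assumed for illustration:

```python
import torch

def prefill(et_module, prompt_tokens, device="cpu"):
    # Exported size of the cache_position input (input index 1); >1 means the
    # PTE was exported with a dynamic sequence dimension.
    seq_len = et_module.method_meta("forward").input_tensor_meta(1).sizes()[0]
    if seq_len > 1:
        # Parallel prefill: feed the whole prompt in a single forward call.
        logits = et_module.forward((
            torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0),
            torch.arange(len(prompt_tokens), dtype=torch.long, device=device),
        ))[0]
        next_token = torch.argmax(logits, dim=-1)[0, -1].item()
    else:
        # Sequential prefill: one token per call, kept for PTEs exported
        # without dynamic shapes.
        for i, token in enumerate(prompt_tokens):
            logits = et_module.forward((
                torch.tensor([token], dtype=torch.long, device=device).unsqueeze(0),
                torch.tensor([i], dtype=torch.long, device=device),
            ))[0]
        next_token = torch.argmax(logits, dim=-1).item()
    return prompt_tokens + [next_token]
```

From here the decode loop appends `next_token` and keeps calling forward with a single token and an advancing cache position until `max_seq_len` is reached.
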
73 changes: 62 additions & 11 deletions optimum/exporters/executorch/integrations.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict

import torch
from packaging.version import parse
from torch.export import ExportedProgram
from torch.nn.attention import SDPBackend
from transformers import (
@@ -45,6 +47,46 @@ def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
self.use_custom_kv_cache = use_custom_kv_cache
self.use_custom_sdpa = use_custom_sdpa
self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
logging.info(f"Metadata to be recorded in PTE: {self.metadata}")

def _prepare_export_inputs(self):
"""
Prepare example inputs and configurations for export.

Returns:
example_input_ids (torch.Tensor): Example input IDs tensor.
example_cache_position (torch.Tensor): Example cache position tensor.
dynamic_shapes (dict or None): Dynamic shape specifications for export.
strict (bool): Whether to use strict export mode.
"""
# Default values for legacy or fallback cases
example_input_ids = torch.tensor([[1]], dtype=torch.long)
example_cache_position = torch.tensor([0], dtype=torch.long)
dynamic_shapes = None
strict = True

is_using_hybrid_cache_wo_custom_sdpa_kv_cache = (
hasattr(self.config, "layer_types")
and getattr(self.config, "sliding_window", None) is not None
and not (self.use_custom_kv_cache and self.use_custom_sdpa)
)

if is_transformers_version(">", "4.52.0") and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache:
# Prepare inputs with dynamic shapes
seq_length = 3 # Sequence length > 1 to avoid specialization issues
example_input_ids = torch.zeros((1, seq_length), dtype=torch.long)
example_cache_position = torch.arange(seq_length, dtype=torch.long)
max_seq_len = self.metadata.get("get_max_seq_len")
sliding_window = self.metadata.get("sliding_window", float("inf"))
max_dim = min(max_seq_len, sliding_window) - 1
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
dynamic_shapes = {
"input_ids": {1: seq_len_dim},
"cache_position": {0: seq_len_dim},
}
strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994

return example_input_ids, example_cache_position, dynamic_shapes, strict

def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
if is_transformers_version(">=", "4.53.0.dev0"):
@@ -65,19 +107,25 @@ def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa"

def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

if is_transformers_version(">=", "4.52.0.dev0"):
def export(
self,
) -> Dict[str, ExportedProgram]:
input_ids, cache_position, dynamic_shapes, strict = self._prepare_export_inputs()
logging.info(
f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}"
)
if is_transformers_version(">", "4.52.0"):
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
)

max_batch_size = 1
max_cache_len = 4094
exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
exportable_module = TorchExportableModuleForDecoderOnlyLM(
self.model,
max_batch_size=1,
max_cache_len=self.metadata.get("get_max_seq_len"),
)
self._register_attention_mask_for_4_53(exportable_module)

if self.use_custom_kv_cache:
from optimum.executorch.attentions.custom_kv_cache import (
replace_with_et_custom_kv_cache,
@@ -91,7 +139,7 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
)

with torch.no_grad():
exported_program = exportable_module.export(example_input_ids, example_cache_position)
exported_program = exportable_module.export(input_ids, cache_position, dynamic_shapes, strict)
# Apply RemoveTransposes pass to remove
# any back-to-back transpose ops that are not needed
# e.g. output of update_cache is transposed and
@@ -103,15 +151,18 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
exported_program = torch.export.export(
mutated_gm,
args=(example_input_ids, example_cache_position),
args=(input_ids, cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict,
)
else:
# Path to use legacy API, static export only due to pinned transformers version
from transformers.integrations.executorch import (
convert_and_export_with_cache,
)

exported_program = convert_and_export_with_cache(self.model, example_input_ids, example_cache_position)
exported_program = convert_and_export_with_cache(self.model, input_ids, cache_position)

return {"model": exported_program}

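The `_prepare_export_inputs` helper above pairs example inputs of length > 1 with a shared `torch.export.Dim` so the sequence dimension stays dynamic in the exported program. A hedged sketch of that export set-up, using a toy module in place of the real `TorchExportableModuleForDecoderOnlyLM` (`ToyDecoder` and its sizes are assumptions for illustration):

```python
import torch
from packaging.version import parse


class ToyDecoder(torch.nn.Module):
    def __init__(self, vocab_size=128, hidden=16, max_len=4096):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.pos = torch.nn.Embedding(max_len, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, cache_position):
        # [1, S, H] token embeddings plus [S, H] position embeddings.
        return self.head(self.embed(input_ids) + self.pos(cache_position))


seq_length = 3  # > 1 so the exporter does not specialize the sequence dim
input_ids = torch.zeros((1, seq_length), dtype=torch.long)
cache_position = torch.arange(seq_length, dtype=torch.long)

max_dim = 4096 - 1  # stands in for min(max_seq_len, sliding_window) - 1
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
dynamic_shapes = {"input_ids": {1: seq_len_dim}, "cache_position": {0: seq_len_dim}}
strict = parse(torch.__version__) != parse("2.7.0")  # same 2.7.0 workaround as above

exported_program = torch.export.export(
    ToyDecoder(),
    args=(input_ids, cache_position),
    dynamic_shapes=dynamic_shapes,
    strict=strict,
)
print(exported_program.graph_signature)
```
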
6 changes: 2 additions & 4 deletions optimum/exporters/executorch/utils.py
@@ -42,16 +42,14 @@ def save_config_to_constant_methods(
"get_vocab_size": getattr(config, "vocab_size", None),
"get_max_batch_size": 1,
"get_max_seq_len": getattr(config, "max_position_embeddings", None),
"use_kv_cache": getattr(generation_config, "use_cache", None),
"sliding_window": getattr(config, "sliding_window", None),
"decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
"use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
}

# Safely access fields from generation_config if it exists
if generation_config is not None:
# Get use_cache with default value
use_cache = getattr(generation_config, "use_cache", None)
metadata["use_kv_cache"] = use_cache

# Check for cache_config and its attributes
cache_config = getattr(generation_config, "cache_config", None)
if cache_config is not None:
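The refactor above moves `use_kv_cache` into the base metadata dict: `getattr` with a `None` default reads it straight from `generation_config`, which also handles a missing or `None` generation config. A small sketch of the resulting pattern, with `SimpleNamespace` objects standing in for the real config classes:

```python
from types import SimpleNamespace

config = SimpleNamespace(vocab_size=32000, max_position_embeddings=4096,
                         sliding_window=None, _attn_implementation="custom_sdpa")
generation_config = SimpleNamespace(use_cache=True)  # may also be None

metadata = {
    "get_vocab_size": getattr(config, "vocab_size", None),
    "get_max_batch_size": 1,
    "get_max_seq_len": getattr(config, "max_position_embeddings", None),
    # Recorded directly; getattr(None, "use_cache", None) simply yields None.
    "use_kv_cache": getattr(generation_config, "use_cache", None),
    "sliding_window": getattr(config, "sliding_window", None),
    "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
}
print(metadata)
```
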
12 changes: 4 additions & 8 deletions tests/models/test_modeling_phi4.py
@@ -19,9 +19,7 @@
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoConfig, AutoTokenizer
from transformers.testing_utils import slow

@@ -75,9 +73,8 @@ def test_phi4_text_generation(self):

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(
parse(torchao.__version__) < parse("0.11.0.dev0"),
reason="Only available on torchao >= 0.11.0.dev0",
@pytest.mark.skip(
reason="Require cache_position support in executorch runtime. Re-enable when available.",
)
def test_phi4_text_generation_with_quantized_pte_from_hub(self):
model_id = "pytorch/Phi-4-mini-instruct-8da4w"
@@ -118,9 +115,8 @@ def test_phi4_text_generation_with_quantized_pte_from_hub(self):

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(
parse(torchao.__version__) < parse("0.11.0.dev0"),
reason="Only available on torchao >= 0.11.0.dev0",
@pytest.mark.skip(
reason="Require cache_position support in executorch runtime. Re-enable when available.",
)
def test_phi4_text_generation_with_quantized_ckp(self):
model_id = "pytorch/Phi-4-mini-instruct-8da4w"