diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index ccc7528d..ca3d3413 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -624,7 +624,14 @@ def forward(
             torch.Tensor: Logits output from the model.
         """
         self.stats.on_model_execution_start()
-        logits = self.model.forward((input_ids, cache_position))[0]
+
+        try:
+            logits = self.model.forward((input_ids, cache_position))[0]
+        except Exception as e:
+            shapes = {name: val.shape for name, val in locals().items() if hasattr(val, "shape")}
+            print(f"Exception: {e}.\n{self.model.method_meta('forward')}\narg shapes: {shapes}")
+            raise
+
         self.stats.on_model_execution_end()
         return logits
 
@@ -667,20 +674,31 @@ def generate(
             )
             max_seq_len = self.max_cache_size
         generated_tokens = []
+        seq_len = self.model.method_meta("forward").input_tensor_meta(1).sizes()[0]
 
-        # prefill
-        for i, prompt_token in enumerate(prompt_tokens):
+        if seq_len > 1:
+            # The model is exported with dynamic shapes. Can support parallel prefill.
             self.stats.on_sampling_begin()
             logits = self.forward(
-                input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
-                cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
+                input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
+                cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
             )
             self.stats.on_sampling_end()
-
+            next_token = torch.argmax(logits, dim=-1)[0, -1].item()
+        else:
+            # Sequential prefill is preserved for backwards compatibility in order to run PTE generated w/o dynamic shapes.
+            # TODO: We can remove this block once the executorch runtime supports `cache_position`.
+            for i, prompt_token in enumerate(prompt_tokens):
+                self.stats.on_sampling_begin()
+                logits = self.forward(
+                    input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0),
+                    cache_position=torch.tensor([i], dtype=torch.long, device=self.device),
+                )
+                self.stats.on_sampling_end()
+            next_token = torch.argmax(logits, dim=-1).item()
         self.stats.on_prompt_eval_end()
         first_token_generated = False
 
-        next_token = torch.argmax(logits, dim=-1).item()
         generated_tokens = prompt_tokens + [next_token]
 
         while len(generated_tokens) < max_seq_len:
diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 7f8cb139..23e6819a 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from typing import Dict
 
 import torch
+from packaging.version import parse
 from torch.export import ExportedProgram
 from torch.nn.attention import SDPBackend
 from transformers import (
@@ -45,6 +47,46 @@ def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
         self.use_custom_kv_cache = use_custom_kv_cache
         self.use_custom_sdpa = use_custom_sdpa
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
+        logging.info(f"Metadata to be recorded in PTE: {self.metadata}")
+
+    def _prepare_export_inputs(self):
+        """
+        Prepare example inputs and configurations for export.
+
+        Returns:
+            example_input_ids (torch.Tensor): Example input IDs tensor.
+            example_cache_position (torch.Tensor): Example cache position tensor.
+            dynamic_shapes (dict or None): Dynamic shape specifications for export.
+            strict (bool): Whether to use strict export mode.
+        """
+        # Default values for legacy or fallback cases
+        example_input_ids = torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = torch.tensor([0], dtype=torch.long)
+        dynamic_shapes = None
+        strict = True
+
+        is_using_hybrid_cache_wo_custom_sdpa_kv_cache = (
+            hasattr(self.config, "layer_types")
+            and getattr(self.config, "sliding_window", None) is not None
+            and not (self.use_custom_kv_cache and self.use_custom_sdpa)
+        )
+
+        if is_transformers_version(">", "4.52.0") and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache:
+            # Prepare inputs with dynamic shapes
+            seq_length = 3  # Sequence length > 1 to avoid specialization issues
+            example_input_ids = torch.zeros((1, seq_length), dtype=torch.long)
+            example_cache_position = torch.arange(seq_length, dtype=torch.long)
+            max_seq_len = self.metadata.get("get_max_seq_len")
+            sliding_window = self.metadata.get("sliding_window", float("inf"))
+            max_dim = min(max_seq_len, sliding_window) - 1
+            seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
+            dynamic_shapes = {
+                "input_ids": {1: seq_len_dim},
+                "cache_position": {0: seq_len_dim},
+            }
+            strict = parse(torch.__version__) != parse("2.7.0")  # Workaround for PyTorch bug #150994
+
+        return example_input_ids, example_cache_position, dynamic_shapes, strict
 
     def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
         if is_transformers_version(">=", "4.53.0.dev0"):
@@ -65,19 +107,25 @@ def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
             # This handles both regular sdpa and one for sliding window/local attention
             exportable_module.model.model.config._attn_implementation = "custom_sdpa"
 
-    def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
-        example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
-        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
-
-        if is_transformers_version(">=", "4.52.0.dev0"):
+    def export(
+        self,
+    ) -> Dict[str, ExportedProgram]:
+        input_ids, cache_position, dynamic_shapes, strict = self._prepare_export_inputs()
+        logging.info(
+            f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}"
+        )
+        if is_transformers_version(">", "4.52.0"):
             from transformers.integrations.executorch import (
                 TorchExportableModuleForDecoderOnlyLM,
             )
 
-            max_batch_size = 1
-            max_cache_len = 4094
-            exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
+            exportable_module = TorchExportableModuleForDecoderOnlyLM(
+                self.model,
+                max_batch_size=1,
+                max_cache_len=self.metadata.get("get_max_seq_len"),
+            )
             self._register_attention_mask_for_4_53(exportable_module)
+
             if self.use_custom_kv_cache:
                 from optimum.executorch.attentions.custom_kv_cache import (
                     replace_with_et_custom_kv_cache,
@@ -91,7 +139,7 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgr
                 )
 
             with torch.no_grad():
-                exported_program = exportable_module.export(example_input_ids, example_cache_position)
+                exported_program = exportable_module.export(input_ids, cache_position, dynamic_shapes, strict)
                 # Apply RemoveTransposes pass to remove
                 # any back-to-back transpose ops that are not needed
                 # e.g. output of update_cache is transposed and
@@ -103,15 +151,18 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgr
                 mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
                 exported_program = torch.export.export(
                     mutated_gm,
-                    args=(example_input_ids, example_cache_position),
+                    args=(input_ids, cache_position),
                     kwargs={},
+                    dynamic_shapes=dynamic_shapes,
+                    strict=strict,
                 )
         else:
+            # Path to use legacy API, static export only due to pinned transformers version
             from transformers.integrations.executorch import (
                 convert_and_export_with_cache,
             )
 
-            exported_program = convert_and_export_with_cache(self.model, example_input_ids, example_cache_position)
+            exported_program = convert_and_export_with_cache(self.model, input_ids, cache_position)
 
         return {"model": exported_program}
 
diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py
index 8d3efe1d..70447957 100644
--- a/optimum/exporters/executorch/utils.py
+++ b/optimum/exporters/executorch/utils.py
@@ -42,16 +42,14 @@ def save_config_to_constant_methods(
         "get_vocab_size": getattr(config, "vocab_size", None),
         "get_max_batch_size": 1,
         "get_max_seq_len": getattr(config, "max_position_embeddings", None),
+        "use_kv_cache": getattr(generation_config, "use_cache", None),
+        "sliding_window": getattr(config, "sliding_window", None),
         "decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
         "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
     }
 
     # Safely access fields from generation_config if it exists
     if generation_config is not None:
-        # Get use_cache with default value
-        use_cache = getattr(generation_config, "use_cache", None)
-        metadata["use_kv_cache"] = use_cache
-
         # Check for cache_config and its attributes
         cache_config = getattr(generation_config, "cache_config", None)
         if cache_config is not None:
diff --git a/tests/models/test_modeling_phi4.py b/tests/models/test_modeling_phi4.py
index 2c3e2bb1..6a8afe79 100644
--- a/tests/models/test_modeling_phi4.py
+++ b/tests/models/test_modeling_phi4.py
@@ -19,9 +19,7 @@
 import unittest
 
 import pytest
-import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from packaging.version import parse
 from transformers import AutoConfig, AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -75,9 +73,8 @@ def test_phi4_text_generation(self):
 
     @slow
     @pytest.mark.run_slow
-    @pytest.mark.skipif(
-        parse(torchao.__version__) < parse("0.11.0.dev0"),
-        reason="Only available on torchao >= 0.11.0.dev0",
+    @pytest.mark.skip(
+        reason="Require cache_position support in executorch runtime. Re-enable when available.",
     )
     def test_phi4_text_generation_with_quantized_pte_from_hub(self):
         model_id = "pytorch/Phi-4-mini-instruct-8da4w"
@@ -118,9 +115,8 @@ def test_phi4_text_generation_with_quantized_pte_from_hub(self):
 
     @slow
     @pytest.mark.run_slow
-    @pytest.mark.skipif(
-        parse(torchao.__version__) < parse("0.11.0.dev0"),
-        reason="Only available on torchao >= 0.11.0.dev0",
+    @pytest.mark.skip(
+        reason="Require cache_position support in executorch runtime. Re-enable when available.",
     )
     def test_phi4_text_generation_with_quantized_ckp(self):
         model_id = "pytorch/Phi-4-mini-instruct-8da4w"
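
For context on the dynamic-shape setup that `_prepare_export_inputs` builds, the standalone sketch below reproduces the same pattern in isolation: a single shared `torch.export.Dim` marks the sequence dimension of both `input_ids` and `cache_position`, so one exported program can serve prompts of different lengths (the parallel-prefill path in `generate` above). `TinyLM`, its sizes, and the `max=127` bound are made-up stand-ins for illustration only; they are not part of this patch.

import torch


class TinyLM(torch.nn.Module):
    """Toy stand-in for the decoder-only LM wrapped by CausalLMExportableModule."""

    def __init__(self, vocab_size: int = 128, hidden: int = 16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
        # cache_position only participates in shapes here; the real model uses it to index the KV cache.
        h = self.embed(input_ids) + cache_position.to(torch.float32).unsqueeze(-1)
        return self.head(h)


# Example inputs with seq_length > 1, mirroring the patch's choice to avoid 0/1 specialization.
seq_length = 3
example_input_ids = torch.zeros((1, seq_length), dtype=torch.long)
example_cache_position = torch.arange(seq_length, dtype=torch.long)

# One shared Dim ties input_ids dim 1 to cache_position dim 0; the max bound is illustrative.
seq_len_dim = torch.export.Dim("seq_length_dim", max=127)
dynamic_shapes = {
    "input_ids": {1: seq_len_dim},
    "cache_position": {0: seq_len_dim},
}

exported = torch.export.export(
    TinyLM(),
    args=(example_input_ids, example_cache_position),
    dynamic_shapes=dynamic_shapes,
)

# The same exported program now accepts prompts of different lengths.
for n in (5, 7):
    logits = exported.module()(torch.zeros((1, n), dtype=torch.long), torch.arange(n))
    print(n, logits.shape)  # torch.Size([1, n, 128])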