21 changes: 13 additions & 8 deletions README.md
@@ -77,15 +77,19 @@ from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import AutoTokenizer

# Load and export the model on-the-fly
model_id = "meta-llama/Llama-3.2-1B"
model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack")
model_id = "HuggingFaceTB/SmolLM2-135M"
model = ExecuTorchModelForCausalLM.from_pretrained(
model_id,
recipe="xnnpack",
attn_implementation="custom_sdpa", # Use custom SDPA implementation for better performance
)

# Generate text right away
tokenizer = AutoTokenizer.from_pretrained(model_id)
generated_text = model.text_generation(
tokenizer=tokenizer,
prompt="Simply put, the theory of relativity states that",
max_seq_len=128
max_seq_len=32,
)
print(generated_text)
```
@@ -99,10 +103,11 @@ print(generated_text)
Use the CLI tool to convert your model to ExecuTorch format:
```
optimum-cli export executorch \
--model "meta-llama/Llama-3.2-1B" \
--model "HuggingFaceTB/SmolLM2-135M" \
--task "text-generation" \
--recipe "xnnpack" \
--output_dir="meta_llama3_2_1b"
--output_dir="hf_smollm2" \
--use_custom_sdpa
```

#### Step 2: Load and run inference
@@ -112,14 +117,14 @@ from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import AutoTokenizer

# Load the exported model
model = ExecuTorchModelForCausalLM.from_pretrained("./meta_llama3_2_1b")
model = ExecuTorchModelForCausalLM.from_pretrained("./hf_smollm2")

# Initialize tokenizer and generate text
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
generated_text = model.text_generation(
tokenizer=tokenizer,
prompt="Simply put, the theory of relativity states that",
max_seq_len=128
max_seq_len=32
)
print(generated_text)
```
26 changes: 13 additions & 13 deletions docs/source/guides/export.mdx
@@ -35,7 +35,12 @@ performance.
Exporting a PyTorch model to ExecuTorch is as simple as

```bash
optimum-cli export executorch --model meta-llama/Llama-3.2-1B --task text-generation --recipe xnnpack --output_dir meta_llama3_2_1b
optimum-cli export executorch \
--model HuggingFaceTB/SmolLM2-135M \
--task text-generation \
--recipe xnnpack \
--output_dir hf_smollm2 \
--use_custom_sdpa
```

Check out the help for more options:
@@ -68,13 +73,14 @@ Required arguments:
classification', 'text2text-generation', 'token-classification', 'object-detection', 'zero-shot-object-detection', 'zero-shot-image-classification', 'text-
generation', 'fill-mask'].
--recipe RECIPE Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".
--use_custom_sdpa Use a custom SDPA implementation with static KV cache for decoder-only models to boost performance. Defaults to False.

```

You should see a `model.pte` file is stored under "./meta_llama3_2_1b/":
You should see a `model.pte` file is stored under "./hf_smollm2/":

```bash
meta_llama3_2_1b/
hf_smollm2/
└── model.pte
```

@@ -87,16 +93,10 @@ For example, we can load and run the model with [ExecuTorch Runtime](https://pyt
from transformers import AutoTokenizer
from optimum.executorch import ExecuTorchModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = ExecuTorchModelForCausalLM.from_pretrained("meta_llama3_2_1b/")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model = ExecuTorchModelForCausalLM.from_pretrained("hf_smollm2/")
prompt = "Simply put, the theory of relativity states that"
generated_text = model.text_generation(tokenizer=tokenizer, prompt=prompt, max_seq_len=45)
```

Printing the `generated_text` would give that:

```
"Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference. In other words, the laws of physics are the same in all inertial frames of reference."
print(f"\nGenerated texts:\n\t{model.text_generation(tokenizer=tokenizer, prompt=prompt, max_seq_len=45)}")
```

As you can see, converting a model to ExecuTorch does not mean leaving the Hugging Face ecosystem. You end up with a similar API as regular 🤗 Transformers models!
@@ -106,5 +106,5 @@ In case your model wasn't already exported to ExecuTorch, it can also be convert
```python
from optimum.executorch import ExecuTorchModelForCausalLM

model = ExecuTorchModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", recipe="xnnpack")
model = ExecuTorchModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", recipe="xnnpack", attn_implementation="custom_sdpa")
```
16 changes: 16 additions & 0 deletions optimum/commands/export/executorch.py
@@ -17,6 +17,9 @@
from pathlib import Path
from typing import TYPE_CHECKING

from executorch import version as executorch_version
from packaging import version as pkg_version

from ...exporters import TasksManager
from ..base import BaseOptimumCLICommand, CommandInfo

@@ -51,6 +54,12 @@ def parse_args_executorch(parser):
default="xnnpack",
help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".',
)
required_group.add_argument(
"--use_custom_sdpa",
required=False,
action="store_true",
help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
)


class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -63,9 +72,16 @@ def parse_args(parser: "ArgumentParser"):
def run(self):
from ...exporters.executorch import main_export

kwargs = {}
if self.args.use_custom_sdpa:
if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
raise ValueError("custom_sdpa is not supported for executorch < 0.6.0")
kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa

main_export(
model_name_or_path=self.args.model,
task=self.args.task,
recipe=self.args.recipe,
output_dir=self.args.output_dir,
**kwargs,
)
70 changes: 70 additions & 0 deletions optimum/executorch/attentions/custom_sdpa.py
@@ -0,0 +1,70 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch
from executorch import version as executorch_version
from packaging import version as pkg_version


if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa # noqa

def custom_sdpa_with_start_pos_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Union[torch.Tensor, "BlockMask"], # noqa
scaling: Optional[float] = None,
softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
# Key/value sequence length, read before the transpose below (key is [batch, heads, seq, head_dim] here)
max_seq_len = key.shape[2]

# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

# Upcast the inputs to fp32 for the custom op; the output is cast back below
input_dtype = query.dtype
query = query.to(torch.float32)
key = key.to(torch.float32)
value = value.to(torch.float32)

# Ignore the causal flag from kwargs but use the one in module
kwargs.pop("is_causal", None)

# Calculate the input position (start_pos) from the attention mask, assuming a
# 2D float mask with 0.0 for attended positions and -inf for masked ones.
attention_mask = attention_mask.reshape(-1, max_seq_len)
first_row_mask = attention_mask[0, :]
# [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
start_pos = torch.argmin(first_row_mask).item() - 1
output = torch.ops.llama.custom_sdpa(
query,
key,
value,
start_pos=start_pos,
attn_mask=None,
# The keyword is spelled 'drpout_p' to match the custom op's schema
drpout_p=0.0,
is_causal=module.is_causal,
scale=scaling,
)
return output.to(input_dtype), None
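
To make the `start_pos` derivation above concrete, here is a minimal, self-contained sketch (the toy mask is illustrative and assumes the float-mask convention the function relies on: 0.0 for attended positions, -inf for masked ones):

```python
import torch

# One decode step: positions 0-3 are attended, the rest are masked,
# so the current token sits at cache position 3.
max_seq_len = 8
attention_mask = torch.tensor([[0.0, 0.0, 0.0, 0.0] + [float("-inf")] * 4])

first_row_mask = attention_mask.reshape(-1, max_seq_len)[0, :]
start_pos = torch.argmin(first_row_mask).item() - 1  # argmin returns the index of the first -inf
print(start_pos)  # 3
```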
2 changes: 2 additions & 0 deletions optimum/executorch/modeling.py
@@ -595,6 +595,8 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
self.eos_token_id = self.model.run_method("get_eos_id")[0]
if "get_vocab_size" in metadata:
self.vocab_size = self.model.run_method("get_vocab_size")[0]
if "use_sdpa_with_kv_cache" in metadata:
self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]

def forward(
self,
12 changes: 12 additions & 0 deletions optimum/exporters/executorch/convert.py
@@ -19,9 +19,21 @@
from pathlib import Path
from typing import Union

from packaging import version as pkg_version
from transformers.modeling_utils import AttentionInterface

from executorch import version as executorch_version

from .recipe_registry import discover_recipes, recipe_registry


if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward

# Register custom sdpa via `AttentionInterface` for executorch>=0.6.0
AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)


logger = logging.getLogger(__name__)


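Once registered here, the `"custom_sdpa"` implementation can also be selected by name on an eager Transformers model via the `AttentionInterface` registry (this is what the new eager test further down exercises). A minimal sketch, assuming `executorch>=0.6.0` is installed and using a small model id purely for illustration:

```python
from transformers import AutoModelForCausalLM

# Importing this module registers "custom_sdpa" with AttentionInterface
# (only on executorch>=0.6.0).
import optimum.exporters.executorch.convert  # noqa: F401

# The eager model now resolves attn_implementation="custom_sdpa" through the
# registry and runs the ExecuTorch custom SDPA op for attention.
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    attn_implementation="custom_sdpa",
)
```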
10 changes: 10 additions & 0 deletions optimum/exporters/executorch/recipes/xnnpack.py
@@ -77,4 +77,14 @@ def _lower_to_executorch(
return et_progs

exported_progs = model.export()

if model.config._attn_implementation == "custom_sdpa":
# Sanity check to make sure the exported program contains the custom sdpa operator.
if not any(
node.op == "call_function" and "custom_sdpa" in str(node.target)
for exported_program in exported_progs.values()
for node in exported_program.graph_module.graph.nodes
):
raise ValueError("'custom_sdpa' not found in the graph.")

return _lower_to_executorch(exported_progs, model.metadata)
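
The same kind of graph inspection can be reproduced in isolation on any `torch.export` program; a minimal sketch (the tiny module and the substring being searched for are illustrative, and the exact ATen target name can vary across PyTorch versions):

```python
import torch
from torch.export import export


class TinyAttention(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.scaled_dot_product_attention(x, x, x)


exported_program = export(TinyAttention(), (torch.randn(1, 2, 4, 8),))

# Walk the FX graph and look for a call_function node whose target name
# contains the operator we expect, mirroring the sanity check above.
found = any(
    node.op == "call_function" and "scaled_dot_product" in str(node.target)
    for node in exported_program.graph_module.graph.nodes
)
print(found)
```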
3 changes: 2 additions & 1 deletion optimum/exporters/executorch/tasks/causal_lm.py
@@ -49,7 +49,8 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
device = "cpu"
batch_size = 1
dtype = kwargs.get("dtype", "float32")
attn_implementation = kwargs.get("attn_implementation", "sdpa")
use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
cache_implementation = kwargs.get("cache_implementation", "static")
max_length = kwargs.get("max_length", 2048)
config = kwargs.get("config", None)
1 change: 1 addition & 0 deletions optimum/exporters/executorch/utils.py
@@ -43,6 +43,7 @@ def save_config_to_constant_methods(
"get_max_batch_size": 1,
"get_max_seq_len": getattr(config, "max_position_embeddings", None),
"decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
"use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
}

# Safely access fields from generation_config if it exists
64 changes: 64 additions & 0 deletions tests/models/test_modeling_common.py
@@ -13,21 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import subprocess
import tempfile
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from executorch import version as executorch_version
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from huggingface_hub import HfApi
from packaging import version as pkg_version
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
)

from optimum.executorch import ExecuTorchModelForCausalLM
from optimum.executorch.modeling import _FILE_PATTERN
from optimum.exporters.executorch import main_export
from optimum.utils.file_utils import find_files_matching_pattern

from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
@@ -97,3 +111,53 @@ def test_find_files_matching_pattern(self):
api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
pte_files = find_files_matching_pattern(local_dir, pattern=_FILE_PATTERN, revision=revision)
self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)

def test_export_with_custom_sdpa(self):
if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
self.skipTest(reason="This test requires executorch >= 0.6 to run.")

model_id = "optimum-internal-testing/tiny-random-llama"
with tempfile.TemporaryDirectory() as tempdir:
subprocess.run(
f"optimum-cli export executorch \
--model {model_id} \
--task 'text-generation' \
--recipe 'xnnpack' \
--output_dir {tempdir}/executorch",
shell=True,
check=True,
)
self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))

def test_eager_text_generation_with_custom_sdpa(self):
if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
self.skipTest(reason="This test requires executorch >= 0.6 to run.")

model_id = "HuggingFaceTB/SmolLM2-135M"
prompt = "My favourite condiment is "
max_seq_len = 32
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Eager model + custom sdpa
cache_implementation = "static"
attn_implementation = "custom_sdpa"
eager_model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_seq_len,
cache_config={
"batch_size": 1,
"max_cache_len": max_seq_len,
},
),
)
self.assertEqual(eager_model.config._attn_implementation, attn_implementation)
eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len, temperature=0)
eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_ids))