4 changes: 3 additions & 1 deletion .github/workflows/test_models.yml
@@ -50,13 +50,15 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies for ExecuTorch
# Consolidate torchao nightly version once https://github.com/pytorch/ao/issues/2157 is fixed
run: |
if [ "${{ matrix.executorch-version }}" == "nightly" ]; then
export NIGHTLY_VERSION=dev20250413
export NIGHTLY_VERSION=dev20250501
pip install executorch==0.7.0.${NIGHTLY_VERSION} \
torch==2.8.0.${NIGHTLY_VERSION} \
torchvision==0.22.0.${NIGHTLY_VERSION} \
torchaudio==2.6.0.${NIGHTLY_VERSION} \
torchao==0.11.0.dev20250422 \
--extra-index-url "https://download.pytorch.org/whl/nightly/cpu"
else
pip install executorch==${{ matrix.executorch-version }}
18 changes: 13 additions & 5 deletions README.md
@@ -129,7 +129,7 @@ generated_text = model.text_generation(
print(generated_text)
```

## Supported Models and Backend
## Supported Models

**Optimum-ExecuTorch** currently supports the following transformer models:

@@ -166,20 +166,28 @@ We currently support a wide range of popular transformer models, including encod
- [Pvt](https://huggingface.co/Zetatech/pvt-tiny-224): Pyramid Vision Transformer (tiny-sized)
- [Swin](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224): Swin Transformer (tiny-sized)

🚀 More coming soon...

### Audio Models
#### Encoder-decoder models
- [Whisper](https://huggingface.co/openai/whisper-tiny): OpenAI's `Whisper` and its variants

*📌 Note: This list is continuously expanding; more models will be added as support grows.*

**Supported Backend:**

Currently, **Optimum-ExecuTorch** supports only the [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack.html) for efficient CPU execution on mobile devices. Quantization support for XNNPACK is planned to be added shortly.
## Supported Optimizations

### Custom Operators
Supports the [custom SDPA](https://github.com/pytorch/executorch/blob/a4322c71c3a97e79e0454a8223db214b010f1193/extension/llm/README.md?plain=1#L40) operator with Hugging Face Transformers, boosting performance by 3x over the default SDPA, based on tests with `HuggingFaceTB/SmolLM2-135M`.
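
A minimal sketch of loading a model with the custom SDPA op enabled (the `recipe` and `attn_implementation` arguments mirror the tests added in this PR; the model id and prompt are illustrative):

```python
from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "HuggingFaceTB/SmolLM2-135M"
# Export to ExecuTorch with the XNNPACK backend and the custom SDPA op
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
generated_text = model.text_generation(
    tokenizer=tokenizer,
    prompt="Once upon a time",
    max_seq_len=64,
)
print(generated_text)
```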

### Backend Delegation
Currently, **Optimum-ExecuTorch** supports the [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack.html) with [custom SDPA](https://github.com/pytorch/executorch/blob/a4322c71c3a97e79e0454a8223db214b010f1193/extension/llm/README.md?plain=1#L40) for efficient execution on mobile CPUs.

For a comprehensive overview of all backends supported by ExecuTorch, please refer to the [ExecuTorch Backend Overview](https://pytorch.org/executorch/main/backends-overview.html).

### Quantization
We currently support Post-Training Quantization (PTQ) for linear layers using int8 dynamic per-token activations and int4 grouped per-channel weights (aka `8da4w`), as well as int8 channelwise embedding quantization.
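
A sketch of enabling both schemes through the `qlinear` and `qembedding` kwargs introduced in this PR (mirroring the Qwen3 test; the model id is illustrative):

```python
from optimum.executorch import ExecuTorchModelForCausalLM

# 8da4w linear quantization + int8 channelwise embedding quantization
model = ExecuTorchModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
    **{"qlinear": True, "qembedding": True},
)
```

The same options are exposed on the CLI as the `--qlinear` and `--qembedding` flags added to `optimum/commands/export/executorch.py`.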

🚀 Stay tuned: more optimizations and performance enhancements are coming soon!


## 🛠️ Advanced Usage

16 changes: 16 additions & 0 deletions optimum/commands/export/executorch.py
@@ -57,6 +57,18 @@ def parse_args_executorch(parser):
action="store_true",
help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
)
required_group.add_argument(
"--qlinear",
required=False,
action="store_true",
help="Quantization config for linear layers. If set, defaults to '8da4w' w/ groupsize 32.",
)
required_group.add_argument(
"--qembedding",
required=False,
action="store_true",
help="Quantization config for embedding. If set, defaults to int8 channelwise.",
)
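# Example invocation (a sketch: the `optimum-cli export executorch` entry point and the
# --model/--recipe flags are assumed from the existing exporter; --use_custom_sdpa,
# --qlinear and --qembedding are the flags defined above):
#   optimum-cli export executorch --model Qwen/Qwen3-0.6B --recipe xnnpack \
#       --use_custom_sdpa --qlinear --qembedding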


class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -72,6 +84,10 @@ def run(self):
kwargs = {}
if self.args.use_custom_sdpa:
kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
if self.args.qlinear:
kwargs["qlinear"] = self.args.qlinear
if self.args.qembedding:
kwargs["qembedding"] = self.args.qembedding

main_export(
model_name_or_path=self.args.model,
5 changes: 4 additions & 1 deletion optimum/executorch/modeling.py
@@ -37,6 +37,7 @@
from transformers.utils import is_offline_mode

from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
from executorch.kernels import quantized # noqa

from ..exporters import TasksManager
from ..exporters.executorch import main_export
@@ -180,7 +181,9 @@ def _from_pretrained(
local_files_only=local_files_only,
)
model = _load_for_executorch(model_cache_path)
logging.info(f"Loaded model from {model_cache_path}")
logging.info(
f"Loaded model from {model_cache_path} ({os.path.getsize(model_cache_path) / (1024 * 1024):.2f} MB)"
)

return {default_file_name.removesuffix(_PTE_SUFFIX): model}

6 changes: 3 additions & 3 deletions optimum/exporters/executorch/convert.py
@@ -26,8 +26,6 @@
from .recipe_registry import discover_recipes, recipe_registry


logger = logging.getLogger(__name__)

AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)


@@ -82,6 +80,8 @@ def export_to_executorch(
full_path = os.path.join(f"{output_dir}", f"{name}.pte")
with open(full_path, "wb") as f:
prog.write_to_file(f)
logger.info(f"Saved exported program to {full_path}")
logging.info(
f"Saved exported program to {full_path} ({os.path.getsize(full_path) / (1024 * 1024):.2f} MB)"
)

return executorch_progs
24 changes: 20 additions & 4 deletions optimum/exporters/executorch/recipes/xnnpack.py
@@ -15,9 +15,13 @@
import logging
from typing import Dict, Union

from packaging.version import parse
from tabulate import tabulate
from torch.export import ExportedProgram

from executorch import version as executorch_version
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
@@ -60,7 +64,14 @@ def _lower_to_executorch(
metadata=None,
) -> Dict[str, ExecutorchProgram]:
et_progs = {}
backend_config_dict = {
"extract_delegate_segments": True,
}
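# Fuse quant/dequant patterns and constant-propagate quantized weights at export time;
# this backend-config option only exists on ExecuTorch > 0.6.0, hence the version gate below.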
if parse(executorch_version.__version__).base_version > "0.6.0":
backend_config_dict["do_quant_fusion_and_const_prop"] = True

for pte_name, exported_program in exported_programs.items():
logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
et_progs[pte_name] = to_edge_transform_and_lower(
exported_program,
partitioner=[XnnpackPartitioner()],
@@ -69,11 +80,16 @@
),
constant_methods=metadata,
).to_executorch(
config=ExecutorchBackendConfig(
extract_delegate_segments=True,
),
config=ExecutorchBackendConfig(**backend_config_dict),
)
logging.debug(
f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
)
delegation_info = get_delegation_info(et_progs[pte_name].exported_program().graph_module)
logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
logging.debug(
f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
)
logging.debug(f"Exported program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}")
return et_progs

exported_progs = model.export()
49 changes: 49 additions & 0 deletions optimum/exporters/executorch/tasks/causal_lm.py
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch
import torchao
from packaging.version import parse
from transformers import AutoModelForCausalLM, GenerationConfig

from ..integrations import CausalLMExportableModule
@@ -71,4 +76,48 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
},
),
)

# TODO: Move quantization recipe out for better composability.
# TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
qlinear_config = kwargs.get("qlinear", None)
qembedding_config = kwargs.get("qembedding", None)
if qlinear_config or qembedding_config:
# TODO: Update torchao to use 0.11.0 once released
if parse(torchao.__version__) < parse("0.11.0.dev0"):
raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
quantize_,
)
from torchao.utils import unwrap_tensor_subclass

if qembedding_config:
logging.info("Quantizing embedding layers.")
# TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
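# Int8 weight-only quantization, channelwise over the embedding rows (PerAxis(0)),
# applied only to nn.Embedding modules via the filter below.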
embedding_config = IntxWeightOnlyConfig(
weight_dtype=torch.int8,
granularity=PerAxis(0),
)
quantize_(
eager_model,
embedding_config,
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

if qlinear_config:
logging.info("Quantizing linear layers.")
linear_config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=PerGroup(32),
)
quantize_(
eager_model,
linear_config,
)

unwrap_tensor_subclass(eager_model)

return CausalLMExportableModule(eager_model)
41 changes: 41 additions & 0 deletions tests/models/test_modeling_gemma3.py
@@ -22,7 +22,9 @@
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoTokenizer
from transformers.testing_utils import slow

@@ -167,3 +169,42 @@ def test_gemma3_text_generation_with_custom_sdpa_float16(self):
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(
parse(torchao.__version__) < parse("0.11.0.dev0"),
reason="Only available on torchao >= 0.11.0.dev0",
)
def test_gemma3_text_generation_with_custom_sdpa_8da4w(self):
# TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
# model_id = "google/gemma-3-1b-it"
model_id = "unsloth/gemma-3-1b-it"
prompt = "Write a poem about a machine learning."

# ExecuTorch model + custom sdpa + 8da4w linear quantization
kwargs = {"qlinear": True}
model = ExecuTorchModelForCausalLM.from_pretrained(
model_id,
recipe="xnnpack",
attn_implementation="custom_sdpa",
**kwargs,
)
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
self.assertIsInstance(model.model, ExecuTorchModule)

tokenizer = AutoTokenizer.from_pretrained(model_id)
generated_text = model.text_generation(
tokenizer=tokenizer,
prompt=prompt,
max_seq_len=64,
)
logging.info(f"\nGenerated text:\n\t{generated_text}")
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

# Free memory before loading eager for quality check
del model
del tokenizer
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
38 changes: 38 additions & 0 deletions tests/models/test_modeling_qwen3.py
@@ -21,7 +21,9 @@
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoTokenizer
from transformers.testing_utils import slow

@@ -136,3 +138,39 @@ def test_qwen3_text_generation_with_custom_sdpa_float16(self):
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(
parse(torchao.__version__) < parse("0.11.0.dev0"),
reason="Only available on torchao >= 0.11.0.dev0",
)
def test_qwen3_text_generation_with_custom_sdpa_8da4w_8we(self):
model_id = "Qwen/Qwen3-0.6B"
prompt = "Give me a short introduction to large language model."
tokenizer = AutoTokenizer.from_pretrained(model_id)

# ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
kwargs = {"qlinear": True, "qembedding": True}
model = ExecuTorchModelForCausalLM.from_pretrained(
model_id,
recipe="xnnpack",
attn_implementation="custom_sdpa",
**kwargs,
)
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
self.assertIsInstance(model.model, ExecuTorchModule)
generated_text = model.text_generation(
tokenizer=tokenizer,
prompt=prompt,
max_seq_len=128,
)
logging.info(f"\nGenerated text:\n\t{generated_text}")
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

# Free memory before loading eager for quality check
del model
del tokenizer
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))