70 changes: 70 additions & 0 deletions optimum/executorch/attentions/custom_sdpa.py
@@ -0,0 +1,70 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch
from executorch import version as executorch_version
from packaging import version as pkg_version


if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa # noqa

def custom_sdpa_with_start_pos_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Union[torch.Tensor, "BlockMask"], # noqa
scaling: Optional[float] = None,
softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
# key is [batch, num_heads, seq_len, head_dim] here, i.e. before the transpose below
max_seq_len = key.shape[2]

# The custom op, like FA2, expects non-transposed [batch, seq_len, num_heads, head_dim] inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

# Upcast the inputs to fp32 for the custom sdpa call; the output is cast back to the original dtype below
input_dtype = query.dtype
query = query.to(torch.float32)
key = key.to(torch.float32)
value = value.to(torch.float32)

# Ignore the is_causal flag passed via kwargs; use module.is_causal instead
kwargs.pop("is_causal", None)

# Calculate the input position (start_pos) from the attention mask.
# The logic below assumes a 2D float mask where attended positions are 0.0
# and masked positions are -inf; bool masks are not handled separately.
attention_mask = attention_mask.reshape(-1, max_seq_len)
first_row_mask = attention_mask[0, :]
# [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
start_pos = torch.argmin(first_row_mask).item() - 1
output = torch.ops.llama.custom_sdpa(
query,
key,
value,
start_pos=start_pos,
attn_mask=None,
drpout_p=0.0,
is_causal=module.is_causal,
scale=scaling,
)
return output.to(input_dtype), None
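
For reference, a minimal sketch (not part of the diff) of how `start_pos` falls out of the first mask row, assuming the float-mask convention described in the comments above (0.0 for attended positions, -inf for masked ones):

```python
import torch

# Toy mask row matching the comment above: four attended positions (0.0)
# followed by masked positions (-inf). argmin picks the first -inf at index 4,
# so start_pos = 4 - 1 = 3.
first_row_mask = torch.tensor([0.0, 0.0, 0.0, 0.0] + [float("-inf")] * 4)
start_pos = torch.argmin(first_row_mask).item() - 1
assert start_pos == 3
```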
2 changes: 2 additions & 0 deletions optimum/executorch/modeling.py
@@ -595,6 +595,8 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
self.eos_token_id = self.model.run_method("get_eos_id")[0]
if "get_vocab_size" in metadata:
self.vocab_size = self.model.run_method("get_vocab_size")[0]
if "use_sdpa_with_kv_cache" in metadata:
self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]

def forward(
self,
12 changes: 12 additions & 0 deletions optimum/exporters/executorch/convert.py
@@ -19,9 +19,21 @@
from pathlib import Path
from typing import Union

from packaging import version as pkg_version
from transformers.modeling_utils import AttentionInterface

from executorch import version as executorch_version

from .recipe_registry import discover_recipes, recipe_registry


if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward

# Register custom sdpa via `AttentionInterface` for executorch>=0.6.0
AttentionInterface.register("executorch_custom_sdpa", custom_sdpa_with_start_pos_forward)


logger = logging.getLogger(__name__)


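Once registered, the custom implementation can be requested by name when loading the eager model, as the tests further down do. A minimal sketch, assuming executorch >= 0.6.0 and that this module has been imported so the registration above has actually run:

```python
import torch
from transformers import AutoModelForCausalLM

import optimum.exporters.executorch.convert  # noqa: F401  # triggers the AttentionInterface registration

# Same checkpoint the tests below use; any decoder-only checkpoint works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    torch_dtype=torch.bfloat16,
    attn_implementation="executorch_custom_sdpa",
)
```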
10 changes: 10 additions & 0 deletions optimum/exporters/executorch/recipes/xnnpack.py
@@ -77,4 +77,14 @@ def _lower_to_executorch(
return et_progs

exported_progs = model.export()

if model.config._attn_implementation == "executorch_custom_sdpa":
# Sanity check to make sure the exported program contains the custom sdpa operator.
if not any(
node.op == "call_function" and "custom_sdpa" in str(node.target)
for exported_program in exported_progs.values()
for node in exported_program.graph_module.graph.nodes
):
raise ValueError("Expected the exported graph to contain the 'custom_sdpa' operator, but none was found.")

return _lower_to_executorch(exported_progs, model.metadata)
1 change: 1 addition & 0 deletions optimum/exporters/executorch/utils.py
@@ -43,6 +43,7 @@ def save_config_to_constant_methods(
"get_max_batch_size": 1,
"get_max_seq_len": getattr(config, "max_position_embeddings", None),
"decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
"use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
}

# Safely access fields from generation_config if it exists
47 changes: 47 additions & 0 deletions tests/models/test_modeling_common.py
@@ -13,21 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import subprocess
import tempfile
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from executorch import version as executorch_version
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from huggingface_hub import HfApi
from packaging import version as pkg_version
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
)

from optimum.executorch import ExecuTorchModelForCausalLM
from optimum.executorch.modeling import _FILE_PATTERN
from optimum.exporters.executorch import main_export
from optimum.utils.file_utils import find_files_matching_pattern

from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
@@ -97,3 +111,36 @@ def test_find_files_matching_pattern(self):
api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
pte_files = find_files_matching_pattern(local_dir, pattern=_FILE_PATTERN, revision=revision)
self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)

def test_eager_text_generation_with_custom_sdpa(self):
if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
self.skipTest(reason="This test requires executorch >= 0.6 to run.")

model_id = "HuggingFaceTB/SmolLM2-135M"
prompt = "My favourite condiment is "
max_seq_len = 32
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Eager model + custom sdpa
cache_implementation = "static"
attn_implementation = "executorch_custom_sdpa"
eager_model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_seq_len,
cache_config={
"batch_size": 1,
"max_cache_len": max_seq_len,
},
),
)
self.assertEqual(eager_model.config._attn_implementation, "executorch_custom_sdpa")
eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len, temperature=0)
eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_ids))
39 changes: 39 additions & 0 deletions tests/models/test_modeling_qwen2.py
@@ -21,7 +21,9 @@
import unittest

import pytest
from executorch import version as executorch_version
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging import version as pkg_version
from transformers import AutoTokenizer
from transformers.testing_utils import slow

@@ -30,6 +32,9 @@
from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -71,3 +76,37 @@ def test_qwen2_5_text_generation(self):
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

@slow
@pytest.mark.run_slow
def test_qwen2_5_text_generation_with_custom_sdpa(self):
if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
self.skipTest(reason="This test requires executorch >= 0.6 to run.")

model_id = "Qwen/Qwen2.5-0.5B"
prompt = "My favourite condiment is "
max_seq_len = 32
tokenizer = AutoTokenizer.from_pretrained(model_id)

# ExecuTorch model + custom sdpa
model = ExecuTorchModelForCausalLM.from_pretrained(
model_id,
recipe="xnnpack",
attn_implementation="executorch_custom_sdpa",
)
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
self.assertIsInstance(model.model, ExecuTorchModule)
generated_text = model.text_generation(
tokenizer=tokenizer,
prompt=prompt,
max_seq_len=max_seq_len,
)
logging.info(f"\nGenerated text:\n\t{generated_text}")
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

# Free memory before loading eager for quality check
del model
del tokenizer
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
40 changes: 40 additions & 0 deletions tests/models/test_modeling_smollm.py
@@ -21,7 +21,9 @@
import unittest

import pytest
from executorch import version as executorch_version
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging import version as pkg_version
from transformers import AutoTokenizer
from transformers.testing_utils import slow

@@ -30,6 +32,9 @@
from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -71,3 +76,38 @@ def test_smollm_text_generation(self):
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

@slow
@pytest.mark.run_slow
def test_smollm_text_generation_with_custom_sdpa(self):
if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
self.skipTest(reason="This test requires executorch >= 0.6 to run.")

model_id = "HuggingFaceTB/SmolLM2-135M"
prompt = "My favourite condiment is "
max_seq_len = 32
tokenizer = AutoTokenizer.from_pretrained(model_id)

# ExecuTorch model + custom sdpa
model = ExecuTorchModelForCausalLM.from_pretrained(
model_id,
recipe="xnnpack",
attn_implementation="executorch_custom_sdpa",
)
self.assertIsInstance(model, ExecuTorchModelForCausalLM)
self.assertIsInstance(model.model, ExecuTorchModule)

generated_text = model.text_generation(
tokenizer=tokenizer,
prompt=prompt,
max_seq_len=max_seq_len,
)
logging.info(f"\nGenerated text:\n\t{generated_text}")
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

# Free memory before loading eager for quality check
del model
del tokenizer
gc.collect()

self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))