Commit 8e9e3c2

Authored and committed by Guang Yang

Use custom sdpa for ExecuTorch
1 parent 2f917c3 commit 8e9e3c2

File tree

8 files changed: +168 -3 lines changed

optimum/executorch/attentions/custom_sdpa.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa
+
+
+def custom_sdpa_with_start_pos_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Union[torch.Tensor, "BlockMask"],
+    scaling: Optional[float] = None,
+    softcap: Optional[float] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, None]:
+    # This is before the transpose
+    max_seq_len = key.shape[2]
+
+    # FA2 uses non-transposed inputs
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    # Upcast the inputs to fp32 for the custom op and cast the output back afterwards
+    input_dtype = query.dtype
+    query = query.to(torch.float32)
+    key = key.to(torch.float32)
+    value = value.to(torch.float32)
+
+    # Ignore the causal flag from kwargs but use the one in module
+    kwargs.pop("is_causal", None)
+
+    # Calculate the input pos from attention mask.
+    # Branch out for float vs bool mask
+    # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
+    attention_mask = attention_mask.reshape(-1, max_seq_len)
+    first_row_mask = attention_mask[0, :]
+    # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
+    start_pos = torch.argmin(first_row_mask).item() - 1
+    output = torch.ops.llama.custom_sdpa(
+        query,
+        key,
+        value,
+        start_pos=start_pos,
+        attn_mask=None,
+        drpout_p=0.0,
+        is_causal=module.is_causal,
+        scale=scaling,
+    )
+    return output.to(input_dtype), None
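
For reference, the way custom_sdpa_with_start_pos_forward recovers the start position from the mask can be reproduced in isolation. The snippet below is an editor's illustrative sketch, not part of the commit: it builds a toy float mask row in which visible positions hold 0.0 and masked positions hold -inf, so torch.argmin lands on the first masked position and argmin - 1 is the current position.

import torch

max_seq_len = 8
first_row_mask = torch.full((max_seq_len,), float("-inf"))
first_row_mask[:4] = 0.0  # positions 0..3 are visible, the rest are masked out
# argmin returns the index of the first -inf (4), so start_pos = 3,
# matching the "[0, 0, 0, 0, -inf, ...], start_pos = 3" comment above.
start_pos = torch.argmin(first_row_mask).item() - 1
print(start_pos)  # 3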

optimum/executorch/modeling.py

Lines changed: 2 additions & 0 deletions

@@ -594,6 +594,8 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
             self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
             self.vocab_size = self.model.run_method("get_vocab_size")[0]
+        if "use_sdpa_with_kv_cache" in metadata:
+            self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]
 
     def forward(
         self,

optimum/exporters/executorch/convert.py

Lines changed: 7 additions & 0 deletions

@@ -19,11 +19,18 @@
 from pathlib import Path
 from typing import Union
 
+from transformers.modeling_utils import AttentionInterface
+
+from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
+
 from .recipe_registry import discover_recipes, recipe_registry
 
 
 logger = logging.getLogger(__name__)
 
+# Register custom sdpa via `AttentionInterface` unconditionally
+AttentionInterface.register("executorch_custom_sdpa", custom_sdpa_with_start_pos_forward)
+
 
 def export_to_executorch(
     model,
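
The AttentionInterface registration above is what makes attn_implementation="executorch_custom_sdpa" resolvable when a model is loaded through transformers. As a rough, hedged sketch of the same mechanism (the function name my_sdpa, its delegation to the stock SDPA helper, and the SmolLM2 checkpoint are illustrative choices, and downloading the checkpoint requires network access):

from transformers import AutoModelForCausalLM
from transformers.modeling_utils import AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward


def my_sdpa(module, query, key, value, attention_mask, **kwargs):
    # Stand-in implementation: simply delegate to the built-in SDPA path.
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)


# Once registered, the name can be requested at load time, exactly like
# "executorch_custom_sdpa" is in the tests further down.
AttentionInterface.register("my_sdpa", my_sdpa)
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M", attn_implementation="my_sdpa"
)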

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 10 additions & 0 deletions

@@ -77,4 +77,14 @@ def _lower_to_executorch(
         return et_progs
 
     exported_progs = model.export()
+
+    if model.config._attn_implementation == "executorch_custom_sdpa":
+        # Sanity check to make sure the exported program contains the custom sdpa operator.
+        if not any(
+            node.op == "call_function" and "custom_sdpa" in str(node.target)
+            for exported_program in exported_progs.values()
+            for node in exported_program.graph_module.graph.nodes
+        ):
+            raise ValueError("'custom_sdpa' not found in the graph.")
+
     return _lower_to_executorch(exported_progs, model.metadata)
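
The sanity check above walks every node of each exported graph and looks for a call to the custom sdpa operator. The same scanning pattern can be tried on plain torch.export output with no ExecuTorch dependency; the toy module below is purely illustrative.

import torch
from torch.export import export


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0


ep = export(Toy(), (torch.randn(2, 2),))
# Same shape of check as in the recipe: scan call_function nodes by target name.
found = any(
    node.op == "call_function" and "relu" in str(node.target)
    for node in ep.graph_module.graph.nodes
)
print(found)  # expected: True, since aten.relu survives export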

optimum/exporters/executorch/utils.py

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ def save_config_to_constant_methods(
         "get_max_batch_size": 1,
         "get_max_seq_len": getattr(config, "max_position_embeddings", None),
         "decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
+        "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
     }
 
     # Safely access fields from generation_config if it exists

setup.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 INSTALL_REQUIRE = [
     "optimum~=1.24",
     "executorch>=0.4.0,!=0.5.0",  # https://github.com/huggingface/optimum-executorch/issues/14
-    "transformers>=4.46,<=4.50.1",
+    "transformers==4.51.0",
 ]
 
 TESTS_REQUIRE = [

tests/models/test_modeling_qwen2.py

Lines changed: 40 additions & 1 deletion

@@ -21,14 +21,17 @@
 
 import pytest
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.testing_utils import slow
 
 from optimum.executorch import ExecuTorchModelForCausalLM
 
 from ..utils import check_causal_lm_output_quality
 
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -63,3 +66,39 @@ def test_qwen2_5_text_generation(self):
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_qwen2_5_text_generation_with_custom_sdpa(self):
+        model_id = "Qwen/Qwen2.5-0.5B"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="executorch_custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=max_seq_len,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))
+
+        # Eager model + custom sdpa
+        eager_model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            attn_implementation="executorch_custom_sdpa",
+        )
+        self.assertEqual(eager_model.config._attn_implementation, "executorch_custom_sdpa")
+        eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
+        eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len)
+        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
+        logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_text))

tests/models/test_modeling_smollm.py

Lines changed: 41 additions & 1 deletion

@@ -21,14 +21,17 @@
 
 import pytest
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.testing_utils import slow
 
 from optimum.executorch import ExecuTorchModelForCausalLM
 
 from ..utils import check_causal_lm_output_quality
 
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -63,3 +66,40 @@ def test_smollm_text_generation(self):
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_smollm_text_generation_with_custom_sdpa(self):
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="executorch_custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=max_seq_len,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))
+
+        # Eager model + custom sdpa
+        eager_model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            attn_implementation="executorch_custom_sdpa",
+        )
+        self.assertEqual(eager_model.config._attn_implementation, "executorch_custom_sdpa")
+        eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
+        eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len)
+        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
+        logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_text))
