Commit cadd829

Author: Guang Yang

Use custom sdpa for ExecuTorch

1 parent 2865126 commit cadd829

File tree: 8 files changed (+207 lines, -0 lines)

8 files changed

+207
-0
lines changed
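
For reference, a minimal usage sketch of what this commit enables, assembled from the new tests further down (the model id, recipe, and prompt are taken from those tests rather than being new API):

from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# "executorch_custom_sdpa" is the attention implementation this commit registers
# via AttentionInterface; it requires executorch >= 0.6.0.
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    attn_implementation="executorch_custom_sdpa",
)

generated_text = model.text_generation(
    tokenizer=tokenizer,
    prompt="My favourite condiment is ",
    max_seq_len=32,
)
print(generated_text)

With executorch >= 0.6.0 installed, the registered implementation routes attention through the custom sdpa op during export; the xnnpack recipe change below then verifies that the op actually ends up in the exported graph.
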
optimum/executorch/attentions/custom_sdpa.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from executorch import version as executorch_version
+from packaging import version as pkg_version
+
+
+if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
+    from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa
+
+    def custom_sdpa_with_start_pos_forward(
+        module: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
+        scaling: Optional[float] = None,
+        softcap: Optional[float] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, None]:
+        # This is before the transpose
+        max_seq_len = key.shape[2]
+
+        # FA2 uses non-transposed inputs
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Convert the inputs to fp32; the output is cast back to the original dtype below
+        input_dtype = query.dtype
+        query = query.to(torch.float32)
+        key = key.to(torch.float32)
+        value = value.to(torch.float32)
+
+        # Ignore the causal flag from kwargs but use the one in module
+        kwargs.pop("is_causal", None)
+
+        # Calculate the input pos from attention mask.
+        # Branch out for float vs bool mask
+        # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
+        attention_mask = attention_mask.reshape(-1, max_seq_len)
+        first_row_mask = attention_mask[0, :]
+        # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
+        start_pos = torch.argmin(first_row_mask).item() - 1
+        output = torch.ops.llama.custom_sdpa(
+            query,
+            key,
+            value,
+            start_pos=start_pos,
+            attn_mask=None,
+            drpout_p=0.0,
+            is_causal=module.is_causal,
+            scale=scaling,
+        )
+        return output.to(input_dtype), None
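
As a small illustration of the start_pos derivation above, a standalone sketch (the mask values are made up; it mirrors the example in the code comment):

import torch

# First row of a float causal mask: attendable positions are 0.0, masked ones -inf.
first_row_mask = torch.tensor([0.0, 0.0, 0.0, 0.0] + [float("-inf")] * 4)

# argmin lands on the first -inf (index 4 here), so the last valid position is
# one less, matching the "[0, 0, 0, 0, -inf, ...], start_pos = 3" comment.
start_pos = torch.argmin(first_row_mask).item() - 1
assert start_pos == 3
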

optimum/executorch/modeling.py

Lines changed: 2 additions & 0 deletions
@@ -594,6 +594,8 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
             self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
             self.vocab_size = self.model.run_method("get_vocab_size")[0]
+        if "use_sdpa_with_kv_cache" in metadata:
+            self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]

     def forward(
         self,

optimum/exporters/executorch/convert.py

Lines changed: 12 additions & 0 deletions
@@ -19,9 +19,21 @@
 from pathlib import Path
 from typing import Union

+from packaging import version as pkg_version
+from transformers.modeling_utils import AttentionInterface
+
+from executorch import version as executorch_version
+
 from .recipe_registry import discover_recipes, recipe_registry


+if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
+    from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
+
+    # Register custom sdpa via `AttentionInterface` for executorch>=0.6.0
+    AttentionInterface.register("executorch_custom_sdpa", custom_sdpa_with_start_pos_forward)
+
+
 logger = logging.getLogger(__name__)

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 10 additions & 0 deletions
@@ -77,4 +77,14 @@ def _lower_to_executorch(
         return et_progs

     exported_progs = model.export()
+
+    if model.config._attn_implementation == "executorch_custom_sdpa":
+        # Sanity check to make sure the exported program contains the custom sdpa operator.
+        if not any(
+            node.op == "call_function" and "custom_sdpa" in str(node.target)
+            for exported_program in exported_progs.values()
+            for node in exported_program.graph_module.graph.nodes
+        ):
+            raise ValueError("'custom_sdpa' not found in the graph.")
+
     return _lower_to_executorch(exported_progs, model.metadata)

optimum/exporters/executorch/utils.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def save_config_to_constant_methods(
         "get_max_batch_size": 1,
         "get_max_seq_len": getattr(config, "max_position_embeddings", None),
         "decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
+        "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
     }

     # Safely access fields from generation_config if it exists

tests/models/test_modeling_common.py

Lines changed: 47 additions & 0 deletions
@@ -13,21 +13,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import os
 import subprocess
 import tempfile
 import unittest
 from pathlib import Path
 from tempfile import TemporaryDirectory

+import torch
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from huggingface_hub import HfApi
+from packaging import version as pkg_version
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+)

 from optimum.executorch import ExecuTorchModelForCausalLM
 from optimum.executorch.modeling import _FILE_PATTERN
 from optimum.exporters.executorch import main_export
 from optimum.utils.file_utils import find_files_matching_pattern

+from ..utils import check_causal_lm_output_quality
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+

 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
@@ -97,3 +111,36 @@ def test_find_files_matching_pattern(self):
             api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
             pte_files = find_files_matching_pattern(local_dir, pattern=_FILE_PATTERN, revision=revision)
             self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)
+
+    def test_eager_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # Eager model + custom sdpa
+        cache_implementation = "static"
+        attn_implementation = "executorch_custom_sdpa"
+        eager_model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            attn_implementation=attn_implementation,
+            generation_config=GenerationConfig(
+                use_cache=True,
+                cache_implementation=cache_implementation,
+                max_length=max_seq_len,
+                cache_config={
+                    "batch_size": 1,
+                    "max_cache_len": max_seq_len,
+                },
+            ),
+        )
+        self.assertEqual(eager_model.config._attn_implementation, "executorch_custom_sdpa")
+        eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
+        eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len, temperature=0)
+        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
+        logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_text))

tests/models/test_modeling_qwen2.py

Lines changed: 32 additions & 0 deletions
@@ -20,7 +20,9 @@
 import unittest

 import pytest
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow

@@ -29,6 +31,9 @@
 from ..utils import check_causal_lm_output_quality


+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -63,3 +68,30 @@ def test_qwen2_5_text_generation(self):
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_qwen2_5_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "Qwen/Qwen2.5-0.5B"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="executorch_custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=max_seq_len,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))

tests/models/test_modeling_smollm.py

Lines changed: 33 additions & 0 deletions
@@ -20,7 +20,9 @@
 import unittest

 import pytest
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow

@@ -29,6 +31,9 @@
 from ..utils import check_causal_lm_output_quality


+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -63,3 +68,31 @@ def test_smollm_text_generation(self):
         )
         logging.info(f"\nGenerated text:\n\t{generated_text}")
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_smollm_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="executorch_custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=max_seq_len,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_text))
