
Commit 2901511

Authored by Guang Yang (guangy10) and co-authors
Use custom SDPA for decoder-only HF Transformers (#46)
* Use custom sdpa for ExecuTorch
* Support export with custom_sdpa using optimum-cli
* Updated docs to reflect using custom_sdpa

Co-authored-by: Guang Yang <[email protected]>
1 parent f92847e · commit 2901511
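The change wires ExecuTorch's custom SDPA kernel into both the Python API and the CLI exporter. A minimal sketch of the Python entry point added here, assuming executorch >= 0.6.0 is installed (model id and parameters mirror the updated docs below):

```python
from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import AutoTokenizer

# Export on the fly with the custom SDPA attention path enabled.
model_id = "HuggingFaceTB/SmolLM2-135M"
model = ExecuTorchModelForCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
print(model.text_generation(tokenizer=tokenizer, prompt="Hello, my name is", max_seq_len=32))
```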

File tree: 14 files changed, +348 −22 lines

README.md

Lines changed: 13 additions & 8 deletions
@@ -77,15 +77,19 @@ from optimum.executorch import ExecuTorchModelForCausalLM
 from transformers import AutoTokenizer
 
 # Load and export the model on-the-fly
-model_id = "meta-llama/Llama-3.2-1B"
-model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack")
+model_id = "HuggingFaceTB/SmolLM2-135M"
+model = ExecuTorchModelForCausalLM.from_pretrained(
+    model_id,
+    recipe="xnnpack",
+    attn_implementation="custom_sdpa",  # Use custom SDPA implementation for better performance
+)
 
 # Generate text right away
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 generated_text = model.text_generation(
     tokenizer=tokenizer,
     prompt="Simply put, the theory of relativity states that",
-    max_seq_len=128
+    max_seq_len=32,
 )
 print(generated_text)
 ```
@@ -99,10 +103,11 @@ print(generated_text)
 Use the CLI tool to convert your model to ExecuTorch format:
 ```
 optimum-cli export executorch \
-    --model "meta-llama/Llama-3.2-1B" \
+    --model "HuggingFaceTB/SmolLM2-135M" \
     --task "text-generation" \
     --recipe "xnnpack" \
-    --output_dir="meta_llama3_2_1b"
+    --output_dir="hf_smollm2" \
+    --use_custom_sdpa
 ```
 
 #### Step 2: Load and run inference
@@ -112,14 +117,14 @@ from optimum.executorch import ExecuTorchModelForCausalLM
 from transformers import AutoTokenizer
 
 # Load the exported model
-model = ExecuTorchModelForCausalLM.from_pretrained("./meta_llama3_2_1b")
+model = ExecuTorchModelForCausalLM.from_pretrained("./hf_smollm2")
 
 # Initialize tokenizer and generate text
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
 generated_text = model.text_generation(
     tokenizer=tokenizer,
     prompt="Simply put, the theory of relativity states that",
-    max_seq_len=128
+    max_seq_len=32
 )
 print(generated_text)
 ```

docs/source/guides/export.mdx

Lines changed: 13 additions & 13 deletions
@@ -35,7 +35,12 @@ performance.
 Exporting a PyTorch model to ExecuTorch is as simple as
 
 ```bash
-optimum-cli export executorch --model meta-llama/Llama-3.2-1B --task text-generation --recipe xnnpack --output_dir meta_llama3_2_1b
+optimum-cli export executorch \
+    --model HuggingFaceTB/SmolLM2-135M \
+    --task text-generation \
+    --recipe xnnpack \
+    --output_dir hf_smollm2 \
+    --use_custom_sdpa
 ```
 
 Check out the help for more options:
@@ -68,13 +73,14 @@ Required arguments:
                         classification', 'text2text-generation', 'token-classification', 'object-detection', 'zero-shot-object-detection', 'zero-shot-image-classification', 'text-
                         generation', 'fill-mask'].
   --recipe RECIPE       Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".
+  --use_custom_sdpa     For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.
 
 ```
 
-You should see a `model.pte` file is stored under "./meta_llama3_2_1b/":
+You should see a `model.pte` file is stored under "./hf_smollm2/":
 
 ```bash
-meta_llama3_2_1b/
+hf_smollm2/
 └── model.pte
 ```
 
@@ -87,16 +93,10 @@ For example, we can load and run the model with [ExecuTorch Runtime](https://pyt
 from transformers import AutoTokenizer
 from optimum.executorch import ExecuTorchModelForCausalLM
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
-model = ExecuTorchModelForCausalLM.from_pretrained("meta_llama3_2_1b/")
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+model = ExecuTorchModelForCausalLM.from_pretrained("hf_smollm2/")
 prompt = "Simply put, the theory of relativity states that"
-generated_text = model.text_generation(tokenizer=tokenizer, prompt=prompt, max_seq_len=45)
-```
-
-Printing the `generated_text` would give that:
-
-```
-"Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference. In other words, the laws of physics are the same in all inertial frames of reference."
+print(f"\nGenerated texts:\n\t{model.text_generation(tokenizer=tokenizer, prompt=prompt, max_seq_len=45)}")
 ```
 
 As you can see, converting a model to ExecuTorch does not mean leaving the Hugging Face ecosystem. You end up with a similar API as regular 🤗 Transformers models!
@@ -106,5 +106,5 @@ In case your model wasn't already exported to ExecuTorch, it can also be convert
 ```python
 from optimum.executorch import ExecuTorchModelForCausalLM
 
-model = ExecuTorchModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", recipe="xnnpack")
+model = ExecuTorchModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", recipe="xnnpack", attn_implementation="custom_sdpa")
 ```

optimum/commands/export/executorch.py

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,9 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
 
+from executorch import version as executorch_version
+from packaging import version as pkg_version
+
 from ...exporters import TasksManager
 from ..base import BaseOptimumCLICommand, CommandInfo
 
@@ -51,6 +54,12 @@ def parse_args_executorch(parser):
         default="xnnpack",
         help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".',
     )
+    required_group.add_argument(
+        "--use_custom_sdpa",
+        required=False,
+        action="store_true",
+        help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
+    )
 
 
 class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -63,9 +72,16 @@ def parse_args(parser: "ArgumentParser"):
     def run(self):
         from ...exporters.executorch import main_export
 
+        kwargs = {}
+        if self.args.use_custom_sdpa:
+            if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+                raise ValueError("custom_sdpa is not supported for executorch < 0.6.0")
+            kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
+
         main_export(
             model_name_or_path=self.args.model,
             task=self.args.task,
             recipe=self.args.recipe,
             output_dir=self.args.output_dir,
+            **kwargs,
         )
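For reference, the same export can be driven programmatically; a minimal sketch of the `main_export` call the command builds above (argument names taken from this diff, everything else assumed). As in the CLI path, `use_custom_sdpa` only makes sense with executorch >= 0.6.0 installed:

```python
from optimum.exporters.executorch import main_export

# Equivalent of `optimum-cli export executorch ... --use_custom_sdpa`.
main_export(
    model_name_or_path="HuggingFaceTB/SmolLM2-135M",
    task="text-generation",
    recipe="xnnpack",
    output_dir="hf_smollm2",
    use_custom_sdpa=True,  # forwarded to the task loader, see causal_lm.py below
)
```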
optimum/executorch/attentions/custom_sdpa.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from executorch import version as executorch_version
+from packaging import version as pkg_version
+
+
+if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
+    from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa
+
+    def custom_sdpa_with_start_pos_forward(
+        module: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
+        scaling: Optional[float] = None,
+        softcap: Optional[float] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, None]:
+        # This is before the transpose
+        max_seq_len = key.shape[2]
+
+        # FA2 uses non-transposed inputs
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Convert the hell out of the inputs to fp32 and back
+        input_dtype = query.dtype
+        query = query.to(torch.float32)
+        key = key.to(torch.float32)
+        value = value.to(torch.float32)
+
+        # Ignore the causal flag from kwargs but use the one in module
+        kwargs.pop("is_causal", None)
+
+        # Calculate the input pos from attention mask.
+        # Branch out for float vs bool mask
+        # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
+        attention_mask = attention_mask.reshape(-1, max_seq_len)
+        first_row_mask = attention_mask[0, :]
+        # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
+        start_pos = torch.argmin(first_row_mask).item() - 1
+        output = torch.ops.llama.custom_sdpa(
+            query,
+            key,
+            value,
+            start_pos=start_pos,
+            attn_mask=None,
+            drpout_p=0.0,
+            is_causal=module.is_causal,
+            scale=scaling,
+        )
+        return output.to(input_dtype), None
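The `start_pos` recovery above is easiest to see on a concrete mask row; a minimal sketch, assuming a float mask where 0 marks attendable positions and -inf marks masked ones:

```python
import torch

# Hypothetical first row of a static-cache attention mask of length 8,
# where the current decode step may attend to positions 0..3.
first_row_mask = torch.tensor([0.0, 0.0, 0.0, 0.0] + [float("-inf")] * 4)

# argmin lands on the first -inf (index 4); the current position is one before it.
start_pos = torch.argmin(first_row_mask).item() - 1
print(start_pos)  # 3, matching the "[0, 0, 0, 0, -inf, ...], start_pos = 3" comment above
```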

optimum/executorch/modeling.py

Lines changed: 2 additions & 0 deletions
@@ -595,6 +595,8 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
             self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
             self.vocab_size = self.model.run_method("get_vocab_size")[0]
+        if "use_sdpa_with_kv_cache" in metadata:
+            self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]
 
     def forward(
         self,

optimum/exporters/executorch/convert.py

Lines changed: 12 additions & 0 deletions
@@ -19,9 +19,21 @@
 from pathlib import Path
 from typing import Union
 
+from packaging import version as pkg_version
+from transformers.modeling_utils import AttentionInterface
+
+from executorch import version as executorch_version
+
 from .recipe_registry import discover_recipes, recipe_registry
 
 
+if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
+    from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
+
+    # Register custom sdpa via `AttentionInterface` for executorch>=0.6.0
+    AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
+
+
 logger = logging.getLogger(__name__)
 
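Once registered, `"custom_sdpa"` can be selected by name like any built-in attention backend; a minimal sketch, assuming the registration above has already run (for example by importing this module) and executorch >= 0.6.0 is installed:

```python
from transformers import AutoModelForCausalLM

import optimum.exporters.executorch.convert  # noqa: F401  # runs AttentionInterface.register

# "custom_sdpa" now resolves to custom_sdpa_with_start_pos_forward.
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    attn_implementation="custom_sdpa",
)
```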

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 10 additions & 0 deletions
@@ -77,4 +77,14 @@ def _lower_to_executorch(
         return et_progs
 
     exported_progs = model.export()
+
+    if model.config._attn_implementation == "custom_sdpa":
+        # Sanity check to make sure the exported program contains the custom sdpa operator.
+        if not any(
+            node.op == "call_function" and "custom_sdpa" in str(node.target)
+            for exported_program in exported_progs.values()
+            for node in exported_program.graph_module.graph.nodes
+        ):
+            raise ValueError("'custom_sdpa' not found in the graph.")
+
     return _lower_to_executorch(exported_progs, model.metadata)
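The sanity check uses a generic `torch.export` pattern: walk the exported graph and look for a `call_function` node whose target names the operator. A minimal, self-contained sketch of the same scan on a toy module (not the recipe code itself):

```python
import torch
from torch.export import export


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


exported_program = export(Toy(), (torch.randn(2, 2),))

# Same node scan as the recipe, here checking for the relu op.
found = any(
    node.op == "call_function" and "relu" in str(node.target)
    for node in exported_program.graph_module.graph.nodes
)
print(found)  # True
```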

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
     device = "cpu"
     batch_size = 1
     dtype = kwargs.get("dtype", "float32")
-    attn_implementation = kwargs.get("attn_implementation", "sdpa")
+    use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
+    attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
     cache_implementation = kwargs.get("cache_implementation", "static")
     max_length = kwargs.get("max_length", 2048)
     config = kwargs.get("config", None)

optimum/exporters/executorch/utils.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def save_config_to_constant_methods(
         "get_max_batch_size": 1,
         "get_max_seq_len": getattr(config, "max_position_embeddings", None),
         "decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
+        "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
     }
 
     # Safely access fields from generation_config if it exists

tests/models/test_modeling_common.py

Lines changed: 64 additions & 0 deletions
@@ -13,21 +13,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 import subprocess
 import tempfile
 import unittest
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
+import torch
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from huggingface_hub import HfApi
+from packaging import version as pkg_version
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+)
 
 from optimum.executorch import ExecuTorchModelForCausalLM
 from optimum.executorch.modeling import _FILE_PATTERN
 from optimum.exporters.executorch import main_export
 from optimum.utils.file_utils import find_files_matching_pattern
 
+from ..utils import check_causal_lm_output_quality
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
@@ -97,3 +111,53 @@ def test_find_files_matching_pattern(self):
             api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
             pte_files = find_files_matching_pattern(local_dir, pattern=_FILE_PATTERN, revision=revision)
             self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)
+
+    def test_export_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "optimum-internal-testing/tiny-random-llama"
+        with tempfile.TemporaryDirectory() as tempdir:
+            subprocess.run(
+                f"optimum-cli export executorch \
+                    --model {model_id} \
+                    --task 'text-generation' \
+                    --recipe 'xnnpack' \
+                    --output_dir {tempdir}/executorch",
+                shell=True,
+                check=True,
+            )
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
+
+    def test_eager_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # Eager model + custom sdpa
+        cache_implementation = "static"
+        attn_implementation = "custom_sdpa"
+        eager_model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            attn_implementation=attn_implementation,
+            generation_config=GenerationConfig(
+                use_cache=True,
+                cache_implementation=cache_implementation,
+                max_length=max_seq_len,
+                cache_config={
+                    "batch_size": 1,
+                    "max_cache_len": max_seq_len,
+                },
+            ),
+        )
+        self.assertTrue(eager_model.config._attn_implementation, attn_implementation)
+        eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
+        eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len, temperature=0)
+        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
+        logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_ids))
