Commit 76d43bc

Qualcomm AI Engine Direct - GA Static Granite3.3-2b (#15808)
### Summary
Add Granite3.3-2b support.

Source model: (two screenshots omitted)

Static llama:
`python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s c3b39f15 -m SM8650 --temperature 0 --model_mode kv --max_seq_len 1024 --prefill_ar_len 128 --decoder_model granite_3_3-2b_instruct --prompt "I would like to learn python, could you teach me with a simple example?" --run_lm_eval --task hellaswag --limit 10 --artifact llama_qnn --kv_updater shift_pointer`

#### Accuracy (hellaswag, limit=10)
- prepare_pt2e: {'acc_norm,none': 0.5}
- convert_pt2e: {'acc_norm,none': 0.3}
- device: {'acc_norm,none': 0.2}

#### Statistics on SM8650 (16a4w_block64)
(screenshot omitted)

#### Statistics on SM8750 (16a4w_block64)
(screenshot omitted)

### Test plan
`python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_granite_3_3_2b_instruct --device c3b39f15 --host mlgtw-linux --model SM8650 --build_folder build-android --executorch_root . --artifact_dir ./llama_qnn --llama_artifacts llama_qnn`

cc @cccclai @cbilgin
1 parent 0c0cee5 commit 76d43bc

File tree

15 files changed (+354, -45 lines)


backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 68 additions & 1 deletion
@@ -5933,7 +5933,7 @@ def test_static_llm_model(self):
             "kv",
             "--max_seq_len",
             "1024",
-            "--eval_perplexity",
+            "--run_lm_eval",
             "--tasks",
             "wikitext",
             "--limit",
@@ -6051,6 +6051,73 @@ def test_codegen2_1b(self):
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 60)

+    def test_granite_3_3_2b_instruct(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "What is the meaning of life?"
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--temperature",
+            "0",
+            "--decoder_model",
+            "granite_3_3-2b_instruct",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--run_lm_eval",
+            "--tasks",
+            "hellaswag",
+            "--limit",
+            "10",
+            "--kv_updater",
+            "shift_pointer",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 20, "SM8750": 22}
+                if (
+                    not self.compile_only
+                    and not self.enable_x86_64
+                    and self.model in inference_speed_ref
+                ):
+                    self.assertLessEqual(msg["pte_size"], 1_600_000_000)
+                    self.assertGreaterEqual(msg["acc_norm"], 0.2)
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_llama_stories_260k(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
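The new test follows the same pattern as the other LLM example tests in this file: it launches `llama.py` as a subprocess, waits for the script to connect back over a `multiprocessing.connection.Listener`, and reads a JSON message carrying the metrics it asserts on (`pte_size`, `acc_norm`, and `inference_speed`, with per-SoC speed thresholds of 20 on SM8650 and 22 on SM8750). A stripped-down sketch of that handshake is below; the sender side is an assumption about what the example script does rather than code from this commit, and the address is a placeholder.

```python
import json
from multiprocessing.connection import Client, Listener

ADDRESS = ("127.0.0.1", 12345)  # placeholder ip/port

def receive_metrics():
    # Receiver side (what the test does): block until the worker connects,
    # then parse the JSON payload it sends.
    with Listener(ADDRESS) as listener:
        conn = listener.accept()
        return json.loads(conn.recv())

def send_metrics(metrics):
    # Sender side (presumably what llama.py does at the end of a run).
    with Client(ADDRESS) as conn:
        conn.send(json.dumps(metrics))
```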
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.granite.convert_weights import convert_weights
+from executorch.examples.models.llama.model import Llama2Model
+
+
+class GraniteModel(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "GraniteModel",
+    "convert_weights",
+]
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+{
+  "dim": 2048,
+  "attention_qkv_bias": false,
+  "attention_multiplier": 0.015625,
+  "bos_idx": 0,
+  "embedding_scale_factor": 12.0,
+  "eos_idx": 0,
+  "act_fn": "silu",
+  "hidden_dim": 8192,
+  "n_heads": 32,
+  "n_layers": 40,
+  "n_kv_heads": 8,
+  "norm_eps": 1e-05,
+  "rope_theta": 10000000.0,
+  "vocab_size": 49159,
+  "use_hf_rope": false,
+  "residual_multiplier": 0.22,
+  "logits_scaling": 8.0
+}
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import argparse
+
+import json
+import os
+from typing import Dict
+
+import torch
+from safetensors.torch import load_file
+
+from torchtune.models.convert_weights import get_mapped_key
+
+
+# Weight mappings from Granite 3's checkpoint to ExecuTorch's transformer parameters.
+_GRANITE_TO_EXECUTORCH = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+}
+
+
+def granite_to_executorch(
+    state_dict: Dict[str, torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+    """
+    Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
+    """
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, _GRANITE_TO_EXECUTORCH)
+        converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+    return converted_state_dict
+
+
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    print("Loading checkpoint...")
+    sd = load_checkpoint(input_dir)
+    print("Converting checkpoint...")
+    sd = granite_to_executorch(sd)
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Granite weights to ExecuTorch transformer format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
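The converter mirrors the existing per-model `convert_weights` helpers: it loads either a sharded safetensors checkpoint or a single `pytorch_model.bin`, remaps HuggingFace parameter names to ExecuTorch's transformer naming, ties `output.weight` to the token embedding, and saves a consolidated checkpoint. A minimal usage sketch (the module path follows the import added elsewhere in this commit; local paths are placeholders):

```python
# Sketch only: convert a locally downloaded ibm-granite/granite-3.3-2b-instruct
# checkpoint into the consolidated .pth the llama example scripts expect.
from executorch.examples.models.granite.convert_weights import convert_weights

convert_weights(
    input_dir="/path/to/granite-3.3-2b-instruct",        # HF checkpoint dir (placeholder)
    output_file="/path/to/granite_3_3_2b_instruct.pth",  # output path (placeholder)
)
```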

examples/models/llama/evaluate/eager_eval.py

Lines changed: 2 additions & 2 deletions
@@ -69,8 +69,8 @@ def device(self):
     def tok_encode(self, string: str, **kwargs):  # pyre-ignore
         return self._tokenizer.encode(string, bos=False, eos=False)

-    def tok_decode(self, tokens):
-        return self._tokenizer.decode(tokens)
+    def tok_decode(self, tokens, **kwargs):
+        return self._tokenizer.decode([tokens] if isinstance(tokens, int) else tokens)

     def _model_call(self, inps):
         if self._use_kv_cache:
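This change makes the eager evaluator tolerant of how the lm-eval harness may call `tok_decode`: extra keyword arguments are now accepted, and a bare token id is wrapped in a list before decoding. A tiny illustrative sketch (not the repository code; `tokenizer` is a stand-in):

```python
def tok_decode(tokenizer, tokens, **kwargs):
    # A single int id becomes [id]; lists pass through unchanged.
    return tokenizer.decode([tokens] if isinstance(tokens, int) else tokens)

# tok_decode(tok, [101, 2023, 102])               -> decodes a token list
# tok_decode(tok, 102, skip_special_tokens=True)  -> single id handled too
```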

examples/models/llama/model_args.py

Lines changed: 9 additions & 0 deletions
@@ -49,6 +49,9 @@ class ModelArgs:
     model_architecture: str = (
         "LlamaForCausalLM"  # This setting is currently only supported for the QNN backend
     )
+    attention_multiplier: Optional[float] = (
+        None  # Scaling factor 1/sqrt(d_k) in attention formula
+    )
     norm_eps: float = 1e-5
     post_attention_norm: bool = False
     post_ffn_norm: bool = False
@@ -75,6 +78,9 @@ class ModelArgs:
     # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
     # logits for all input tokens.)
     generate_full_logits: bool = False
+    logits_scaling: Optional[float] = (
+        None  # Scaling factor applied to the logits of model, functioning similarly to a temperature parameter.
+    )
     enable_dynamic_shape: bool = False  # export model with dynamic shape support
     # A dictionary mapping from pruned token-id to original token-id
     input_prune_map: Optional[Dict[int, int]] = None
@@ -85,6 +91,9 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    residual_multiplier: Optional[float] = (
+        None  # Scaling factor applied to the residual hidden states
+    )
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     no_rope_layer_interval: Optional[int] = (
         None  # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795).
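These three fields carry Granite's architectural scaling knobs into `ModelArgs`. For the 2b config added in this commit, `attention_multiplier` is 0.015625, i.e. 1/64 with head_dim = dim / n_heads = 2048 / 32 = 64. A minimal sketch of how Granite-style multipliers are typically applied (illustration only, not the ExecuTorch transformer code; values taken from 2b_config.json):

```python
import torch

def attention_scores(q, k, attention_multiplier=0.015625):
    # Fixed multiplier replaces the usual 1/sqrt(head_dim) score scaling.
    return torch.matmul(q, k.transpose(-2, -1)) * attention_multiplier

def residual_add(hidden, block_output, residual_multiplier=0.22):
    # Block outputs are damped before being added back to the residual stream.
    return hidden + residual_multiplier * block_output

def scale_logits(logits, logits_scaling=8.0):
    # Final logits are divided by a fixed factor, acting like a built-in temperature.
    return logits / logits_scaling
```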

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 23 additions & 15 deletions
@@ -2,17 +2,19 @@

 ## Overview
 This file provides you the instructions to run LLM Decoder model with different parameters via Qualcomm HTP backend. We currently support the following models:
+<!-- numbered list will be automatically generated -->
 1. LLAMA2 Stories 110M
-2. LLAMA3.2 1B
-3. LLAMA3.2 3B
-4. Codegen2 1B
-5. Gemma 2B
-6. Gemma3 1B
-7. Phi4-mini-instruct
-8. QWEN2.5 0.5B / 1.5B
-9. QWEN3 0.6B / 1.7B
-10. SmolLM2 135M
-11. SmolLM3 3B
+1. LLAMA3.2 1B
+1. LLAMA3.2 3B
+1. Codegen2 1B
+1. Gemma 2B
+1. Gemma3 1B
+1. Granite3.3 2B
+1. Phi4-mini-instruct
+1. QWEN2.5 0.5B / 1.5B
+1. QWEN3 0.6B / 1.7B
+1. SmolLM2 135M
+1. SmolLM3 3B


 We offer the following modes to execute the model:
@@ -100,6 +102,12 @@ Default example using hybrid mode
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```

+#### Granite3.3 2B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model granite_3_3-2b_instruct --prompt "I would like to learn python, could you teach me with a simple example?" --run_lm_eval --task hellaswag --limit 10
+```
+
 #### Phi4-mini-instruct
 Default example using kv mode.
 ```bash
@@ -227,24 +235,24 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### Perplexity Evaluation
 This script supports perplexity evaluation and is capable of assessing perplexity scores across 3 phases: prepare_pt2e(CPU FP), convert_pt2e(CPU QDQ), QNN on device.

-To evaluate the perplexity across all 3 phases, users should provide the `--eval_perplexity` flag and specify the evaluation task. Please notice when this flag is provided, the `--prompt ${PROMPT}` will be ignored.
+To evaluate the perplexity across all 3 phases, users should provide the `--run_lm_eval` flag and specify the evaluation task. Please notice when this flag is provided, the `--prompt ${PROMPT}` will be ignored.

 For example, using the Qwen model and 1 wikitext sample as the evaluation task, users can assess all 3 phases perplexity score in a single run by including the appropriate configuration:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --run_lm_eval --tasks wikitext --limit 1
 ```

 For the example script above, 1 wikitext sample is used to evaluate all 3 phases. However, there are cases where a user may want to use one sample for quantization calibration and multiple samples for perplexity evaluation. In this case, the process should be split into two runs. In the 1st run, the model is compiled using one sample. In the 2nd run, the user can provide a different configuration for QNN device execution.
 Example:
 ```bash
 # 1st run to compile with --limit 1
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1 --compile_only
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --run_lm_eval --tasks wikitext --limit 1 --compile_only
 ```
 ```bash
 # 2nd run to perform QNN device execution with --limit 3
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --run_lm_eval --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json
 ```

 #### Tasks quantization calibration
 If `--tasks ${TASK}` is not provided, the program will use `--prompt ${PROMPT}` as the dataset for quantization calibration.
-Regardless of whether `--eval_perplexity` is provided, as long as `--tasks ${TASK}` is specified, the specified tasks will be used for model quantization calibration instead of the prompt.
+Regardless of whether `--run_lm_eval` is provided, as long as `--tasks ${TASK}` is specified, the specified tasks will be used for model quantization calibration instead of the prompt.

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 32 additions & 0 deletions
@@ -29,6 +29,9 @@

 from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
 from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
+from executorch.examples.models.granite import (
+    convert_weights as convert_granite_weights,
+)
 from executorch.examples.models.phi_4_mini import (
     convert_weights as convert_phi_4_mini_weights,
 )
@@ -385,6 +388,35 @@ class Gemma3(LLMModelConfig):
     )


+@register_llm_model("granite_3_3-2b_instruct")
+@dataclass(init=False, frozen=True)
+class Granite_3_3_2b_Instruct(LLMModelConfig):
+    repo_id: str = "ibm-granite/granite-3.3-2b-instruct"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/granite/config/2b_config.json"
+    )
+    convert_weights = convert_granite_weights
+    transform_weight = False
+    instruct_model = True
+
+    num_sharding = 1
+    # quant config
+    ptq = QuantDtype.use_16a4w_block
+    group_size = 64
+    masked_softmax = True
+    seq_mse_candidates = 0
+    r1 = False
+    r2 = False
+    r3 = False
+    quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+    )
+    custom_annotation = (
+        annotate_kv_8bit,
+        partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w),
+    )
+
+
 @register_llm_model("phi_4_mini")
 @dataclass(init=False, frozen=True)
 class Phi4Mini(LLMModelConfig):

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
     "stories110m": "llama2",
     "gemma-2b": "gemma",
     "gemma3-1b": "gemma3",
+    "granite_3_3-2b_instruct": "granite",
     "phi_4_mini": "phi_4_mini",
     "llama3_2-1b_instruct": "llama3",
    "llama3_2-3b_instruct": "llama3",
