
Commit 2af1cd4

Guang Yang committed
Support lowering quantized checkpoint from Hub
1 parent da80c9e commit 2af1cd4
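
What this commit enables, in practice: a quantized Phi-4-mini artifact hosted on the Hub can either be loaded as a pre-lowered .pte program or be lowered from its quantized checkpoint at load time. A minimal sketch mirroring the new tests below (the model id, file name, recipe, and prompt are taken from the test code and are illustrative, not an API guarantee):

    from transformers import AutoConfig, AutoTokenizer

    from optimum.executorch import ExecuTorchModelForCausalLM

    model_id = "pytorch/Phi-4-mini-instruct-8da4w"
    config = AutoConfig.from_pretrained(model_id)
    if getattr(config, "rope_scaling", None) is not None:
        # Same workaround as the tests: avoid data-dependent control flow in _longrope_frequency_update.
        config.rope_scaling["type"] = "default"

    # Path 1: load the pre-lowered, quantized .pte program directly from the Hub.
    model = ExecuTorchModelForCausalLM.from_pretrained(
        model_id, recipe="xnnpack", config=config, file_name="phi4-mini-8da4w.pte"
    )

    # Path 2: lower the quantized checkpoint at load time instead.
    # model = ExecuTorchModelForCausalLM.from_pretrained(
    #     model_id, recipe="xnnpack", config=config, export=True
    # )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    print(model.text_generation(tokenizer=tokenizer, prompt="My favourite condiment is ", max_seq_len=64))

Both paths end up with a single ExecuTorchModule exposed as `model.model`, which is what the modeling.py change below standardizes.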

File tree (2 files changed: +113, −10 lines)

  optimum/executorch/modeling.py
  tests/models/test_modeling_phi4.py

optimum/executorch/modeling.py

Lines changed: 15 additions & 9 deletions
@@ -92,8 +92,12 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
                 f"This attribute is used to identify the corresponding AutoModel class."
             )
 
-        for key, value in models.items():
-            setattr(self, key, value)
+        if len(models) == 1:
+            # For single PTE, always set the attr to "model"
+            setattr(self, "model", next(iter(models.values())))
+        else:
+            for key, value in models.items():
+                setattr(self, key, value)
 
         self.stats = Stats()

@@ -570,8 +574,8 @@ class ExecuTorchModelForCausalLM(ExecuTorchModelBase):
             Data type of the model parameters.
         bos_token_id (`int`):
             Beginning-of-sequence token ID.
-        eos_token_id (`int`):
-            End-of-sequence token ID.
+        eos_token_ids (`List[int]`):
+            End-of-sequence token IDs.
         vocab_size (`int`):
             Size of the model vocabulary.
     """

@@ -594,8 +598,10 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
         self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
             self.bos_token_id = self.model.run_method("get_bos_id")[0]
-        if "get_eos_id" in metadata:
-            self.eos_token_id = self.model.run_method("get_eos_id")[0]
+        for key in ("get_eos_id", "get_eos_ids"):
+            if key in metadata:
+                self.eos_token_ids = self.model.run_method("get_eos_ids")
+                break
         if "get_vocab_size" in metadata:
             self.vocab_size = self.model.run_method("get_vocab_size")[0]
         if "use_sdpa_with_kv_cache" in metadata:

@@ -694,7 +700,7 @@ def generate(
             next_token = torch.argmax(logits, dim=-1).item()
             generated_tokens.append(next_token)
 
-            if next_token == self.eos_token_id:
+            if next_token in self.eos_token_ids:
                 break
 
         self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))

@@ -730,9 +736,9 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id:
+        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id not in self.eos_token_ids:
             raise ValueError(
-                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}."
+                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must match with the model's eos_token_ids={self.eos_token_ids}."
             )
 
         # Reset stats for a new generation
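
The caller-visible effect of this file: end-of-sequence handling is now list-based (`eos_token_ids` replaces the scalar `eos_token_id`), and a single-program artifact is always exposed via the `model` attribute regardless of its key in the models dict. A rough caller-side sketch, assuming `model` and `generated_tokens` come from an `ExecuTorchModelForCausalLM` loaded as in the example above:

    # Caller-side sketch (hypothetical names): eos handling is now list-based.
    eos_ids = model.eos_token_ids          # List[int], replaces the scalar eos_token_id
    last_token = generated_tokens[-1]      # assume these ids came from model.generate(...)
    if last_token in eos_ids:              # membership test replaces `== eos_token_id`
        print("generation stopped at an end-of-sequence token")

    # With a single PTE program, the ExecuTorchModule is always exposed as `model.model`.
    assert hasattr(model, "model")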

tests/models/test_modeling_phi4.py

Lines changed: 98 additions & 1 deletion
@@ -15,10 +15,13 @@
 
 import gc
 import logging
+import os
 import unittest
 
 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoConfig, AutoTokenizer
 from transformers.testing_utils import slow

@@ -27,13 +30,18 @@
 from ..utils import check_causal_lm_output_quality
 
 
-@pytest.mark.skip(reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM")
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+is_ci = os.environ.get("GITHUB_ACTIONS") == "true"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     @slow
     @pytest.mark.run_slow
+    @pytest.mark.skip(is_ci, reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM")
     def test_phi4_text_generation(self):
         model_id = "microsoft/Phi-4-mini-instruct"
         config = AutoConfig.from_pretrained(model_id)

@@ -61,3 +69,92 @@ def test_phi4_text_generation(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_phi4_text_generation_with_quantized_pte_from_hub(self):
+        model_id = "pytorch/Phi-4-mini-instruct-8da4w"
+        config = AutoConfig.from_pretrained(model_id)
+        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
+        # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
+        # that function to avoid the data-dependent control flow.
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            config.rope_scaling["type"] = "default"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id, recipe="xnnpack", config=config, file_name="phi4-mini-8da4w.pte"
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="My favourite condiment is ",
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+
+        if not is_ci:
+            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+            # Free memory before loading eager for quality check
+            del model
+            del tokenizer
+            gc.collect()
+
+            self.assertTrue(
+                check_causal_lm_output_quality(
+                    "microsoft/Phi-4-mini-instruct",
+                    generated_tokens,
+                )
+            )
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_phi4_text_generation_with_quantized_ckp(self):
+        model_id = "pytorch/Phi-4-mini-instruct-8da4w"
+        config = AutoConfig.from_pretrained(model_id)
+        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
+        # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
+        # that function to avoid the data-dependent control flow.
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            config.rope_scaling["type"] = "default"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            config=config,
+            export=True,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="My favourite condiment is ",
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+
+        if not is_ci:
+            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+            # Free memory before loading eager for quality check
+            del model
+            del tokenizer
+            gc.collect()
+
+            self.assertTrue(
+                check_causal_lm_output_quality(
+                    "microsoft/Phi-4-mini-instruct",
+                    generated_tokens,
+                )
+            )
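
Both new tests are gated by the slow markers and a torchao >= 0.11.0.dev0 check, and the output-quality assertion is skipped on CI runners. Assuming the repository's usual slow-test convention (the `@slow` decorator from `transformers.testing_utils` honors the `RUN_SLOW` environment variable), they can be run locally with something like `RUN_SLOW=1 pytest tests/models/test_modeling_phi4.py -k quantized`.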
