 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import os
 import subprocess
 import tempfile
 import unittest
 from pathlib import Path
 from tempfile import TemporaryDirectory

+import torch
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from huggingface_hub import HfApi
+from packaging import version as pkg_version
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+)

 from optimum.executorch import ExecuTorchModelForCausalLM
 from optimum.executorch.modeling import _FILE_PATTERN
 from optimum.exporters.executorch import main_export
 from optimum.utils.file_utils import find_files_matching_pattern

+from ..utils import check_causal_lm_output_quality
+
+
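+# Silence the tokenizers fork/parallelism warning triggered when tests spawn subprocesses.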
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+


 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
@@ -97,3 +111,53 @@ def test_find_files_matching_pattern(self): |
         api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
         pte_files = find_files_matching_pattern(local_dir, pattern=_FILE_PATTERN, revision=revision)
         self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)
+
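+    # CLI export smoke test: export a tiny random llama with the xnnpack recipe
+    # and check that a serialized ExecuTorch program is produced.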
+    def test_export_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "optimum-internal-testing/tiny-random-llama"
+        with tempfile.TemporaryDirectory() as tempdir:
+            subprocess.run(
+                f"optimum-cli export executorch \
+                    --model {model_id} \
+                    --task 'text-generation' \
+                    --recipe 'xnnpack' \
+                    --output_dir {tempdir}/executorch",
+                shell=True,
+                check=True,
+            )
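+            # The export should have written the serialized program to <output_dir>/model.pte.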
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
+
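+    # Eager-mode generation with the custom SDPA attention: the model should still
+    # produce coherent text when attention is routed through the custom kernel.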
+    def test_eager_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # Eager model + custom sdpa
+        cache_implementation = "static"
+        attn_implementation = "custom_sdpa"
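+        # "custom_sdpa" selects the ExecuTorch custom SDPA kernel; the static cache
+        # pre-allocates fixed-size KV buffers (batch_size x max_cache_len below).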
+        eager_model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            attn_implementation=attn_implementation,
+            generation_config=GenerationConfig(
+                use_cache=True,
+                cache_implementation=cache_implementation,
+                max_length=max_seq_len,
+                cache_config={
+                    "batch_size": 1,
+                    "max_cache_len": max_seq_len,
+                },
+            ),
+        )
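+        # The requested attention implementation should be recorded on the config.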
+        self.assertEqual(eager_model.config._attn_implementation, attn_implementation)
+        eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
+        eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len, do_sample=False)  # greedy decoding
+        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
+        logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_ids))
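
For context, a minimal sketch of driving the exported program end-to-end, assuming `ExecuTorchModelForCausalLM.from_pretrained` accepts a `recipe` argument and exposes a `text_generation` helper (both assumptions, not shown in this diff):

    from transformers import AutoTokenizer

    from optimum.executorch import ExecuTorchModelForCausalLM

    model_id = "HuggingFaceTB/SmolLM2-135M"
    # Export with the xnnpack recipe and load the resulting .pte (assumed API).
    model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Greedy-decode a short continuation, mirroring the eager test above.
    print(model.text_generation(tokenizer=tokenizer, prompt="My favourite condiment is ", max_seq_len=32))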