diff --git a/README.md b/README.md
index db7daff2..d3cf3563 100644
--- a/README.md
+++ b/README.md
@@ -77,15 +77,19 @@ from optimum.executorch import ExecuTorchModelForCausalLM
 from transformers import AutoTokenizer
 
 # Load and export the model on-the-fly
-model_id = "meta-llama/Llama-3.2-1B"
-model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack")
+model_id = "HuggingFaceTB/SmolLM2-135M"
+model = ExecuTorchModelForCausalLM.from_pretrained(
+    model_id,
+    recipe="xnnpack",
+    attn_implementation="custom_sdpa",  # Use custom SDPA implementation for better performance
+)
 
 # Generate text right away
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 generated_text = model.text_generation(
     tokenizer=tokenizer,
     prompt="Simply put, the theory of relativity states that",
-    max_seq_len=128
+    max_seq_len=32,
 )
 print(generated_text)
 ```
@@ -99,10 +103,11 @@ print(generated_text)
 Use the CLI tool to convert your model to ExecuTorch format:
 ```
 optimum-cli export executorch \
-    --model "meta-llama/Llama-3.2-1B" \
+    --model "HuggingFaceTB/SmolLM2-135M" \
     --task "text-generation" \
     --recipe "xnnpack" \
-    --output_dir="meta_llama3_2_1b"
+    --output_dir="hf_smollm2" \
+    --use_custom_sdpa
 ```
 
 #### Step 2: Load and run inference
@@ -112,14 +117,14 @@ from optimum.executorch import ExecuTorchModelForCausalLM
 from transformers import AutoTokenizer
 
 # Load the exported model
-model = ExecuTorchModelForCausalLM.from_pretrained("./meta_llama3_2_1b")
+model = ExecuTorchModelForCausalLM.from_pretrained("./hf_smollm2")
 
 # Initialize tokenizer and generate text
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
 generated_text = model.text_generation(
     tokenizer=tokenizer,
     prompt="Simply put, the theory of relativity states that",
-    max_seq_len=128
+    max_seq_len=32
 )
 print(generated_text)
 ```
diff --git a/docs/source/guides/export.mdx b/docs/source/guides/export.mdx
index 5ac6e0eb..ff276cc3 100644
--- a/docs/source/guides/export.mdx
+++ b/docs/source/guides/export.mdx
@@ -35,7 +35,12 @@ performance.
 Exporting a PyTorch model to ExecuTorch is as simple as
 
 ```bash
-optimum-cli export executorch --model meta-llama/Llama-3.2-1B --task text-generation --recipe xnnpack --output_dir meta_llama3_2_1b
+optimum-cli export executorch \
+    --model HuggingFaceTB/SmolLM2-135M \
+    --task text-generation \
+    --recipe xnnpack \
+    --output_dir hf_smollm2 \
+    --use_custom_sdpa
 ```
 
 Check out the help for more options:
@@ -68,13 +73,14 @@ Required arguments:
                         classification', 'text2text-generation', 'token-classification', 'object-detection', 'zero-shot-object-detection', 'zero-shot-image-classification', 'text-
                         generation', 'fill-mask'].
   --recipe RECIPE       Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".
+  --use_custom_sdpa     For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.
 ```
 
-You should see a `model.pte` file is stored under "./meta_llama3_2_1b/":
+You should see a `model.pte` file stored under "./hf_smollm2/":
 
 ```bash
-meta_llama3_2_1b/
+hf_smollm2/
 └── model.pte
 ```
 
@@ -87,16 +93,10 @@ For example, we can load and run the model with [ExecuTorch Runtime](https://pyt
 from transformers import AutoTokenizer
 from optimum.executorch import ExecuTorchModelForCausalLM
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
-model = ExecuTorchModelForCausalLM.from_pretrained("meta_llama3_2_1b/")
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+model = ExecuTorchModelForCausalLM.from_pretrained("hf_smollm2/")
 prompt = "Simply put, the theory of relativity states that"
-generated_text = model.text_generation(tokenizer=tokenizer, prompt=prompt, max_seq_len=45)
-```
-
-Printing the `generated_text` would give that:
-
-```
-"Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference. In other words, the laws of physics are the same in all inertial frames of reference."
+print(f"\nGenerated text:\n\t{model.text_generation(tokenizer=tokenizer, prompt=prompt, max_seq_len=45)}")
 ```
 
 As you can see, converting a model to ExecuTorch does not mean leaving the Hugging Face ecosystem. You end up with a similar API as regular 🤗 Transformers models!
@@ -106,5 +106,5 @@ In case your model wasn't already exported to ExecuTorch, it can also be convert
 ```python
 from optimum.executorch import ExecuTorchModelForCausalLM
 
-model = ExecuTorchModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", recipe="xnnpack")
+model = ExecuTorchModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", recipe="xnnpack", attn_implementation="custom_sdpa")
 ```
diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py
index 62ca5b0a..7b777ab2 100644
--- a/optimum/commands/export/executorch.py
+++ b/optimum/commands/export/executorch.py
@@ -17,6 +17,9 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
 
+from executorch import version as executorch_version
+from packaging import version as pkg_version
+
 from ...exporters import TasksManager
 from ..base import BaseOptimumCLICommand, CommandInfo
 
@@ -51,6 +54,12 @@ def parse_args_executorch(parser):
         default="xnnpack",
         help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".',
     )
+    required_group.add_argument(
+        "--use_custom_sdpa",
+        required=False,
+        action="store_true",
+        help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
+    )
 
 
 class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -63,9 +72,16 @@ def parse_args(parser: "ArgumentParser"):
     def run(self):
         from ...exporters.executorch import main_export
 
+        kwargs = {}
+        if self.args.use_custom_sdpa:
+            if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+                raise ValueError("custom_sdpa is not supported for executorch < 0.6.0")
+            kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
+
         main_export(
             model_name_or_path=self.args.model,
             task=self.args.task,
             recipe=self.args.recipe,
             output_dir=self.args.output_dir,
+            **kwargs,
         )
diff --git a/optimum/executorch/attentions/custom_sdpa.py b/optimum/executorch/attentions/custom_sdpa.py
new file mode 100644
index 00000000..1f6310bf
--- /dev/null
+++ b/optimum/executorch/attentions/custom_sdpa.py
@@ -0,0 +1,70 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from executorch import version as executorch_version
+from packaging import version as pkg_version
+
+
+if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
+    from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa
+
+    def custom_sdpa_with_start_pos_forward(
+        module: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
+        scaling: Optional[float] = None,
+        softcap: Optional[float] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, None]:
+        # Sequence length of the key/value tensors, read before the transpose below
+        max_seq_len = key.shape[2]
+
+        # The custom op, like FA2, expects non-transposed [batch, seq, heads, dim] inputs
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Cast the inputs to fp32 for the custom op; the output is cast back to the original dtype below
+        input_dtype = query.dtype
+        query = query.to(torch.float32)
+        key = key.to(torch.float32)
+        value = value.to(torch.float32)
+
+        # Ignore the causal flag from kwargs but use the one in module
+        kwargs.pop("is_causal", None)
+
+        # Calculate the input position (start_pos) from the attention mask, which is expected to be
+        # a float mask with 0 for attendable positions and -inf for masked ones.
+        # Reshape it to (-1, max_seq_len) and inspect its first row.
+        attention_mask = attention_mask.reshape(-1, max_seq_len)
+        first_row_mask = attention_mask[0, :]
+        # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
+        start_pos = torch.argmin(first_row_mask).item() - 1
+        output = torch.ops.llama.custom_sdpa(
+            query,
+            key,
+            value,
+            start_pos=start_pos,
+            attn_mask=None,
+            drpout_p=0.0,
+            is_causal=module.is_causal,
+            scale=scaling,
+        )
+        return output.to(input_dtype), None
diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index e156ef65..619ad12b 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -595,6 +595,8 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
             self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
             self.vocab_size = self.model.run_method("get_vocab_size")[0]
+        if "use_sdpa_with_kv_cache" in metadata:
+            self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]
 
     def forward(
         self,
diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py
index ba16153f..cd646e32 100644
--- a/optimum/exporters/executorch/convert.py
+++ b/optimum/exporters/executorch/convert.py
@@ -19,9 +19,21 @@
 from pathlib import Path
 from typing import Union
 
+from packaging import version as pkg_version
+from transformers.modeling_utils import AttentionInterface
+
+from executorch import version as executorch_version
+
 from .recipe_registry import discover_recipes, recipe_registry
 
 
+if pkg_version.parse(executorch_version.__version__) >= pkg_version.parse("0.6.0"):
+    from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
+
+    # Register custom sdpa via `AttentionInterface` for executorch>=0.6.0
+    AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
+
+
 logger = logging.getLogger(__name__)
diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py
index 4511b7c7..e2949167 100644
--- a/optimum/exporters/executorch/recipes/xnnpack.py
+++ b/optimum/exporters/executorch/recipes/xnnpack.py
@@ -77,4 +77,14 @@ def _lower_to_executorch(
         return et_progs
 
     exported_progs = model.export()
+
+    if model.config._attn_implementation == "custom_sdpa":
+        # Sanity check to make sure the exported program contains the custom sdpa operator.
+        if not any(
+            node.op == "call_function" and "custom_sdpa" in str(node.target)
+            for exported_program in exported_progs.values()
+            for node in exported_program.graph_module.graph.nodes
+        ):
+            raise ValueError("'custom_sdpa' not found in the graph.")
+
     return _lower_to_executorch(exported_progs, model.metadata)
diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py
index fb91ffc0..b1f570c6 100644
--- a/optimum/exporters/executorch/tasks/causal_lm.py
+++ b/optimum/exporters/executorch/tasks/causal_lm.py
@@ -49,7 +49,8 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
     device = "cpu"
     batch_size = 1
     dtype = kwargs.get("dtype", "float32")
-    attn_implementation = kwargs.get("attn_implementation", "sdpa")
+    use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
+    attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
     cache_implementation = kwargs.get("cache_implementation", "static")
     max_length = kwargs.get("max_length", 2048)
     config = kwargs.get("config", None)
diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py
index 725ff02f..acc2ab3f 100644
--- a/optimum/exporters/executorch/utils.py
+++ b/optimum/exporters/executorch/utils.py
@@ -43,6 +43,7 @@ def save_config_to_constant_methods(
         "get_max_batch_size": 1,
         "get_max_seq_len": getattr(config, "max_position_embeddings", None),
         "decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
+        "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
     }
 
     # Safely access fields from generation_config if it exists
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 1678aacf..4014bf6e 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 import subprocess
 import tempfile
@@ -20,14 +21,27 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
+import torch
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from huggingface_hub import HfApi
+from packaging import version as pkg_version
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+)
 
 from optimum.executorch import ExecuTorchModelForCausalLM
 from optimum.executorch.modeling import _FILE_PATTERN
 from optimum.exporters.executorch import main_export
 from optimum.utils.file_utils import find_files_matching_pattern
 
+from ..utils import check_causal_lm_output_quality
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
@@ -97,3 +111,53 @@ def test_find_files_matching_pattern(self):
             api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision)
             pte_files = find_files_matching_pattern(local_dir, pattern=_FILE_PATTERN, revision=revision)
             self.assertTrue(len(pte_files) == 0 if revision == "main" else len(pte_files) > 0)
+
+    def test_export_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "optimum-internal-testing/tiny-random-llama"
+        with tempfile.TemporaryDirectory() as tempdir:
+            subprocess.run(
+                f"optimum-cli export executorch \
+                    --model {model_id} \
+                    --task 'text-generation' \
+                    --recipe 'xnnpack' \
+                    --output_dir {tempdir}/executorch --use_custom_sdpa",
+                shell=True,
+                check=True,
+            )
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
+
+    def test_eager_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # Eager model + custom sdpa
+        cache_implementation = "static"
+        attn_implementation = "custom_sdpa"
+        eager_model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            attn_implementation=attn_implementation,
+            generation_config=GenerationConfig(
+                use_cache=True,
+                cache_implementation=cache_implementation,
+                max_length=max_seq_len,
+                cache_config={
+                    "batch_size": 1,
+                    "max_cache_len": max_seq_len,
+                },
+            ),
+        )
+        self.assertEqual(eager_model.config._attn_implementation, attn_implementation)
+        eager_inputs = tokenizer(prompt, return_tensors="pt").to(eager_model.device)
+        eager_generated_ids = eager_model.generate(**eager_inputs, max_new_tokens=max_seq_len, temperature=0)
+        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
+        logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
+        self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_ids))
diff --git a/tests/models/test_modeling_llama.py b/tests/models/test_modeling_llama.py
index 1cd1a55a..0f4d5125 100644
--- a/tests/models/test_modeling_llama.py
+++ b/tests/models/test_modeling_llama.py
@@ -21,7 +21,9 @@ import unittest
 
 import pytest
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -73,3 +75,34 @@ def test_llama3_2_1b_text_generation(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_llama_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        # ExecuTorch model + custom sdpa
+        model_id = "NousResearch/Llama-3.2-1B"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="Simply put, the theory of relativity states that",
+            max_seq_len=32,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
diff --git a/tests/models/test_modeling_olmo.py b/tests/models/test_modeling_olmo.py
index 67742954..44d37c8d 100644
--- a/tests/models/test_modeling_olmo.py
+++ b/tests/models/test_modeling_olmo.py
@@ -21,7 +21,9 @@ import unittest
 
 import pytest
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -71,3 +73,34 @@ def test_olmo_text_generation_with_xnnpack(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_olmo_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        # ExecuTorch model + custom sdpa
+        model_id = "allenai/OLMo-1B-hf"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="My favourite condiment is ",
+            max_seq_len=32,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
diff --git a/tests/models/test_modeling_qwen2.py b/tests/models/test_modeling_qwen2.py
index 388cb61f..26dd9afd 100644
--- a/tests/models/test_modeling_qwen2.py
+++ b/tests/models/test_modeling_qwen2.py
@@ -21,7 +21,9 @@ import unittest
 
 import pytest
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -30,6 +32,9 @@
 from ..utils import check_causal_lm_output_quality
 
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -71,3 +76,37 @@ def test_qwen2_5_text_generation(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_qwen2_5_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "Qwen/Qwen2.5-0.5B"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=max_seq_len,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
diff --git a/tests/models/test_modeling_smollm.py b/tests/models/test_modeling_smollm.py
index 3c0f3a5b..ca5bb882 100644
--- a/tests/models/test_modeling_smollm.py
+++ b/tests/models/test_modeling_smollm.py
@@ -21,7 +21,9 @@ import unittest
 
 import pytest
+from executorch import version as executorch_version
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging import version as pkg_version
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow
 
@@ -30,6 +32,9 @@
 from ..utils import check_causal_lm_output_quality
 
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -71,3 +76,38 @@ def test_smollm_text_generation(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    def test_smollm_text_generation_with_custom_sdpa(self):
+        if pkg_version.parse(executorch_version.__version__) < pkg_version.parse("0.6.0"):
+            self.skipTest(reason="This test requires executorch >= 0.6 to run.")
+
+        model_id = "HuggingFaceTB/SmolLM2-135M"
+        prompt = "My favourite condiment is "
+        max_seq_len = 32
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # ExecuTorch model + custom sdpa
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_seq_len=max_seq_len,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        # Free memory before loading eager for quality check
+        del model
+        del tokenizer
+        gc.collect()
+
+        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
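
Side note (not part of the patch): the trickiest piece above is how `custom_sdpa_with_start_pos_forward` derives `start_pos` from the attention mask. A minimal standalone sketch of that derivation, assuming the float-mask convention the code itself documents (0 for attendable positions, -inf for masked ones):

```python
import torch

# Mask row for a static KV cache of length 8 where positions 0..3 are attendable:
# [0, 0, 0, 0, -inf, -inf, -inf, -inf]  ->  start_pos = 3 (as in the comment in custom_sdpa.py)
first_row_mask = torch.tensor([0.0, 0.0, 0.0, 0.0] + [float("-inf")] * 4)

# argmin returns the index of the first -inf entry (here 4); the current position is one before it.
start_pos = torch.argmin(first_row_mask).item() - 1
print(start_pos)  # 3
```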