diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py
index 44a8fdc2a04..f5d879b3b8b 100644
--- a/backends/arm/test/models/test_llama.py
+++ b/backends/arm/test/models/test_llama.py
@@ -33,24 +33,32 @@ class TestLlama(unittest.TestCase):
     """
     Test class of Llama models. Type of Llama model depends on command line parameters:
-    --llama_inputs
-    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
+    --llama_inputs
+    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json stories110m
+    For more examples and info see examples/models/llama/README.md.
     """
 
     def prepare_model(self):
         checkpoint = None
         params_file = None
+        usage = "To run use --llama_inputs <.pt/.pth> <.json> <model_name>"
+
         if conftest.is_option_enabled("llama_inputs"):
             param_list = conftest.get_option("llama_inputs")
-            assert (
-                isinstance(param_list, list) and len(param_list) == 2
-            ), "invalid number of inputs for --llama_inputs"
+
+            if not isinstance(param_list, list) or len(param_list) != 3:
+                raise RuntimeError(
+                    f"Invalid number of inputs for --llama_inputs. {usage}"
+                )
+            if not all(isinstance(param, str) for param in param_list):
+                raise RuntimeError(
+                    f"All --llama_inputs are expected to be strings. {usage}"
+                )
+
             checkpoint = param_list[0]
             params_file = param_list[1]
-            assert isinstance(checkpoint, str) and isinstance(
-                params_file, str
-            ), "invalid input for --llama_inputs"
+            model_name = param_list[2]
         else:
             logger.warning(
-                "Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json>"
+                f"Skipping Llama tests because of missing --llama_inputs. {usage}"
             )
             return None, None, None
@@ -71,7 +79,7 @@ def prepare_model(self):
             "-p",
             params_file,
             "--model",
-            "stories110m",
+            model_name,
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
@@ -122,6 +130,7 @@ def test_llama_tosa_BI(self):
             .quantize()
             .export()
             .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=llama_inputs,