Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions backends/arm/test/models/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,35 @@
class TestLlama(unittest.TestCase):
"""
Test class of Llama models. Type of Llama model depends on command line parameters:
--llama_inputs <path to .pt file> <path to json file>
Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
--llama_inputs <path to .pt file> <path to json file> <name of model variant>
Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json stories110m
For more examples and info see examples/models/llama/README.md.
"""

def prepare_model(self):

checkpoint = None
params_file = None
usage = "To run use --llama_inputs <.pt/.pth> <.json> <name>"

if conftest.is_option_enabled("llama_inputs"):
param_list = conftest.get_option("llama_inputs")
assert (
isinstance(param_list, list) and len(param_list) == 2
), "invalid number of inputs for --llama_inputs"

if not isinstance(param_list, list) or len(param_list) != 3:
raise RuntimeError(
f"Invalid number of inputs for --llama_inputs. {usage}"
)
if not all(isinstance(param, str) for param in param_list):
raise RuntimeError(
f"All --llama_inputs are expected to be strings. {usage}"
)

checkpoint = param_list[0]
params_file = param_list[1]
assert isinstance(checkpoint, str) and isinstance(
params_file, str
), "invalid input for --llama_inputs"
model_name = param_list[2]
else:
logger.warning(
"Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json>"
f"Skipping Llama tests because of missing --llama_inputs. {usage}"
)
return None, None, None

Expand All @@ -71,7 +79,7 @@ def prepare_model(self):
"-p",
params_file,
"--model",
"stories110m",
model_name,
]
parser = build_args_parser()
args = parser.parse_args(args)
Expand Down Expand Up @@ -122,6 +130,7 @@ def test_llama_tosa_BI(self):
.quantize()
.export()
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(
inputs=llama_inputs,
Expand Down
Loading