From ab3ef949200ed97fc539a785f70151e70dc4586c Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Wed, 15 Oct 2025 13:21:44 +0200 Subject: [PATCH] Arm backend: Reduce complexity of get_model_and_inputs_from_name Change-Id: Icef6d5ae312e34268295f705cbc42b7c46aa12cf Signed-off-by: Sebastian Larsson --- examples/arm/aot_arm_compiler.py | 223 +++++++++++++++++++++---------- 1 file changed, 150 insertions(+), 73 deletions(-) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 34ed7e3f1bd..db248f0bf56 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -71,76 +71,151 @@ logging.basicConfig(level=logging.WARNING, format=FORMAT) -def get_model_and_inputs_from_name( - model_name: str, model_input: str | None -) -> Tuple[torch.nn.Module, Any]: - """Given the name of an example pytorch model, return it and example inputs. +def _load_example_inputs(model_input: str | None) -> Any: + """Load example inputs from a `.pt` file when a path is provided.""" + if model_input is None: + return None + + logging.info(f"Load model input from {model_input}") + + if model_input.endswith(".pt"): + return torch.load(model_input, weights_only=False) + + raise RuntimeError( + f"Model input data '{model_input}' is not a valid name. Use --model_input " + ".pt e.g. saved with torch.save()" + ) + + +def _load_internal_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: + """Load a bundled example model from the internal `MODELS` mapping.""" + if model_name not in MODELS: + return None + + logging.info(f"Internal model {model_name}") + + model = MODELS[model_name]() + inputs = ( + example_inputs + if example_inputs is not None + else MODELS[model_name].example_input + ) + + return model, inputs + + +def _load_registered_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: + """Load a registered example model from `examples.models`.""" + if model_name not in MODEL_NAME_TO_MODEL: + return None + + logging.warning( + "Using a model from examples/models not all of these are currently supported" + ) + logging.info( + f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models" + ) + + model, tmp_example_inputs, _, _ = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[model_name] + ) + inputs = example_inputs if example_inputs is not None else tmp_example_inputs + + return model, inputs + + +def _load_python_module_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: + """Load a model and inputs from a Python source file. + + The file must define `ModelUnderTest` and `ModelInputs` attributes. - Raises RuntimeError if there is no example model corresponding to the given name. """ - example_inputs = None - if model_input is not None: - logging.info(f"Load model input from {model_input}") - if model_input.endswith(".pt"): - example_inputs = torch.load(model_input, weights_only=False) - else: - raise RuntimeError( - f"Model input data '{model_input}' is not a valid name. Use --model_input .pt e.g. saved with torch.save()" - ) + if not model_name.endswith(".py"): + return None - # Case 1: Model is defined in this file - if model_name in models.keys(): - logging.info(f"Internal model {model_name}") - model = models[model_name]() - if example_inputs is None: - example_inputs = models[model_name].example_input - # Case 2: Model is defined in examples/models/ - elif model_name in MODEL_NAME_TO_MODEL.keys(): - logging.warning( - "Using a model from examples/models not all of these are currently supported" - ) - logging.info( - f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models" - ) + logging.info( + f"Load model file {model_name} " + "Variable ModelUnderTest= ModelInputs=" + ) - model, tmp_example_inputs, _, _ = EagerModelFactory.create_model( - *MODEL_NAME_TO_MODEL[model_name] - ) - if example_inputs is None: - example_inputs = tmp_example_inputs - # Case 3: Model is in an external python file loaded as a module. - # ModelUnderTest should be a torch.nn.module instance - # ModelInputs should be a tuple of inputs to the forward function - elif model_name.endswith(".py"): - logging.info( - f"Load model file {model_name} Variable ModelUnderTest= ModelInputs=" - ) - import importlib.util - - # load model's module and add it - spec = importlib.util.spec_from_file_location("tmp_model", model_name) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - model = module.ModelUnderTest - if example_inputs is None: - example_inputs = module.ModelInputs - # Case 4: Model is in an saved model file torch.save(model) - elif model_name.endswith(".pth") or model_name.endswith(".pt"): - logging.info(f"Load model file {model_name}") - model = torch.load(model_name, weights_only=False) - if example_inputs is None: - raise RuntimeError( - f"Model '{model_name}' requires input data specify --model_input .pt" - ) - else: + import importlib.util + + spec = importlib.util.spec_from_file_location("tmp_model", model_name) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load model file {model_name}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + model = module.ModelUnderTest + inputs = example_inputs if example_inputs is not None else module.ModelInputs + + return model, inputs + + +def _load_serialized_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: + """Load a serialized Torch model saved via `torch.save`.""" + if not model_name.endswith((".pth", ".pt")): + return None + + logging.info(f"Load model file {model_name}") + + model = torch.load(model_name, weights_only=False) + if example_inputs is None: raise RuntimeError( - f"Model '{model_name}' is not a valid name. Use --help for a list of available models." + f"Model '{model_name}' requires input data specify --model_input .pt" ) - logging.debug(f"Loaded model: {model}") - logging.debug(f"Loaded input: {example_inputs}") + return model, example_inputs +def get_model_and_inputs_from_name( + model_name: str, model_input: str | None +) -> Tuple[torch.nn.Module, Any]: + """Resolve a model name into a model instance and example inputs. + + Args: + model_name: Identifier for the model. It can be a key in + `MODEL_NAME_TO_MODEL`, a Python module path, or a serialized + model file path. + model_input: Optional path to a `.pt` file containing example inputs. + + Returns: + Tuple of `(model, example_inputs)` ready for compilation. + + Raises: + RuntimeError: If the model cannot be resolved or required inputs are + missing. + + """ + example_inputs = _load_example_inputs(model_input) + + loaders = ( + _load_internal_model, + _load_registered_model, + _load_python_module_model, + _load_serialized_model, + ) + + for loader in loaders: + result = loader(model_name, example_inputs) + if result is not None: + model, example_inputs = result + logging.debug(f"Loaded model: {model}") + logging.debug(f"Loaded input: {example_inputs}") + return model, example_inputs + + raise RuntimeError( + f"Model '{model_name}' is not a valid name. Use --help for a list of available models." + ) + + def quantize( model: GraphModule, model_name: str, @@ -150,7 +225,9 @@ def quantize( evaluator_config: Dict[str, Any] | None, ) -> GraphModule: """This is the official recommended flow for quantization in pytorch 2.0 - export""" + export. + + """ logging.info("Quantizing Model...") logging.debug(f"Original model: {model}") @@ -238,7 +315,7 @@ def forward(self, x): can_delegate = True -models = { +MODELS = { "qadd": QuantAddTest, "qadd2": QuantAddTest2, "qops": QuantOpTest, @@ -247,7 +324,7 @@ def forward(self, x): "qlinear": QuantLinearTest, } -calibration_data = { +CALIBRATION_DATA = { "qadd": (torch.randn(32, 2, 1),), "qadd2": ( torch.randn(32, 2, 1), @@ -261,7 +338,7 @@ def forward(self, x): ), } -targets = [ +TARGETS = [ "ethos-u55-32", "ethos-u55-64", "ethos-u55-128", @@ -289,10 +366,10 @@ def get_calibration_data( if evaluator_data is not None: return evaluator_data - # If the model is in the calibration_data dictionary, get the data from there + # If the model is in the CALIBRATION_DATA dictionary, get the data from there # This is used for the simple model examples provided - if model_name in calibration_data: - return calibration_data[model_name] + if model_name in CALIBRATION_DATA: + return CALIBRATION_DATA[model_name] # As a last resort, fallback to the scripts previous behavior and return the example inputs return example_inputs @@ -365,7 +442,7 @@ def get_args(): "-m", "--model_name", required=True, - help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(models.keys())+list(MODEL_NAME_TO_MODEL.keys()))}", + help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(MODELS.keys()) + list(MODEL_NAME_TO_MODEL.keys()))}", ) parser.add_argument( "--model_input", @@ -401,8 +478,8 @@ def get_args(): action="store", required=False, default="ethos-u55-128", - choices=targets, - help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}", + choices=TARGETS, + help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {TARGETS}", ) parser.add_argument( "-e", @@ -506,9 +583,9 @@ def get_args(): torch.ops.load_library(args.so_library) if ( - args.model_name in models.keys() + args.model_name in MODELS.keys() and args.delegate is True - and models[args.model_name].can_delegate is False + and MODELS[args.model_name].can_delegate is False ): raise RuntimeError(f"Model {args.model_name} cannot be delegated.")