examples/arm/aot_arm_compiler.py: 223 changes (150 additions, 73 deletions)
@@ -71,76 +71,151 @@
logging.basicConfig(level=logging.WARNING, format=FORMAT)


-def get_model_and_inputs_from_name(
-    model_name: str, model_input: str | None
-) -> Tuple[torch.nn.Module, Any]:
-    """Given the name of an example pytorch model, return it and example inputs.
+def _load_example_inputs(model_input: str | None) -> Any:
+    """Load example inputs from a `.pt` file when a path is provided."""
+    if model_input is None:
+        return None
+
+    logging.info(f"Load model input from {model_input}")
+
+    if model_input.endswith(".pt"):
+        return torch.load(model_input, weights_only=False)
+
+    raise RuntimeError(
+        f"Model input data '{model_input}' is not a valid input file. Use "
+        "--model_input <FILE>.pt, e.g. saved with torch.save()"
+    )
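
A minimal sketch of producing the file `--model_input` expects; the file name
and tensor shape here are illustrative assumptions, not part of this patch:

    import torch

    # A tuple matching the model's forward() signature; _load_example_inputs()
    # recovers it via torch.load(..., weights_only=False).
    example_inputs = (torch.randn(32, 2, 1),)
    torch.save(example_inputs, "model_input.pt")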


+def _load_internal_model(
+    model_name: str, example_inputs: Any
+) -> Optional[Tuple[torch.nn.Module, Any]]:
+    """Load a bundled example model from the internal `MODELS` mapping."""
+    if model_name not in MODELS:
+        return None
+
+    logging.info(f"Internal model {model_name}")
+
+    model = MODELS[model_name]()
+    inputs = (
+        example_inputs
+        if example_inputs is not None
+        else MODELS[model_name].example_input
+    )
+
+    return model, inputs
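
A quick illustration of the loader contract shared by these helpers: None
means the loader does not handle the name, a tuple means it resolved the model
("qadd" is a key of the MODELS mapping defined later in this file):

    assert _load_internal_model("no-such-model", None) is None

    # With no explicit inputs, the model's bundled example_input is used.
    model, inputs = _load_internal_model("qadd", None)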


+def _load_registered_model(
+    model_name: str, example_inputs: Any
+) -> Optional[Tuple[torch.nn.Module, Any]]:
+    """Load a registered example model from `examples.models`."""
+    if model_name not in MODEL_NAME_TO_MODEL:
+        return None
+
+    logging.warning(
+        "Using a model from examples/models; not all of these are currently supported"
+    )
+    logging.info(
+        f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models"
+    )
+
+    model, tmp_example_inputs, _, _ = EagerModelFactory.create_model(
+        *MODEL_NAME_TO_MODEL[model_name]
+    )
+    inputs = example_inputs if example_inputs is not None else tmp_example_inputs
+
+    return model, inputs
+
+
+def _load_python_module_model(
+    model_name: str, example_inputs: Any
+) -> Optional[Tuple[torch.nn.Module, Any]]:
+    """Load a model and inputs from a Python source file.

+    The file must define `ModelUnderTest` and `ModelInputs` attributes.
+
-    Raises RuntimeError if there is no example model corresponding to the given name.
    """
-    example_inputs = None
-    if model_input is not None:
-        logging.info(f"Load model input from {model_input}")
-        if model_input.endswith(".pt"):
-            example_inputs = torch.load(model_input, weights_only=False)
-        else:
-            raise RuntimeError(
-                f"Model input data '{model_input}' is not a valid name. Use --model_input <FILE>.pt e.g. saved with torch.save()"
-            )
+    if not model_name.endswith(".py"):
+        return None

-    # Case 1: Model is defined in this file
-    if model_name in models.keys():
-        logging.info(f"Internal model {model_name}")
-        model = models[model_name]()
-        if example_inputs is None:
-            example_inputs = models[model_name].example_input
-    # Case 2: Model is defined in examples/models/
-    elif model_name in MODEL_NAME_TO_MODEL.keys():
-        logging.warning(
-            "Using a model from examples/models not all of these are currently supported"
-        )
-        logging.info(
-            f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models"
-        )
+    logging.info(
+        f"Load model file {model_name} "
+        "Variable ModelUnderTest=<Model> ModelInputs=<ModelInput>"
+    )

-        model, tmp_example_inputs, _, _ = EagerModelFactory.create_model(
-            *MODEL_NAME_TO_MODEL[model_name]
-        )
-        if example_inputs is None:
-            example_inputs = tmp_example_inputs
-    # Case 3: Model is in an external python file loaded as a module.
-    #         ModelUnderTest should be a torch.nn.module instance
-    #         ModelInputs should be a tuple of inputs to the forward function
-    elif model_name.endswith(".py"):
-        logging.info(
-            f"Load model file {model_name} Variable ModelUnderTest=<Model> ModelInputs=<ModelInput>"
-        )
-        import importlib.util
-
-        # load model's module and add it
-        spec = importlib.util.spec_from_file_location("tmp_model", model_name)
-        module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(module)
-        model = module.ModelUnderTest
-        if example_inputs is None:
-            example_inputs = module.ModelInputs
-    # Case 4: Model is in an saved model file torch.save(model)
-    elif model_name.endswith(".pth") or model_name.endswith(".pt"):
-        logging.info(f"Load model file {model_name}")
-        model = torch.load(model_name, weights_only=False)
-        if example_inputs is None:
-            raise RuntimeError(
-                f"Model '{model_name}' requires input data specify --model_input <FILE>.pt"
-            )
-    else:
+    import importlib.util
+
+    spec = importlib.util.spec_from_file_location("tmp_model", model_name)
+    if spec is None or spec.loader is None:
+        raise RuntimeError(f"Unable to load model file {model_name}")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    model = module.ModelUnderTest
+    inputs = example_inputs if example_inputs is not None else module.ModelInputs
+
+    return model, inputs
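
A sketch of an external model file satisfying this contract; the module body
and shapes are illustrative, and only the ModelUnderTest / ModelInputs names
are required:

    # my_model.py, used as: --model_name my_model.py
    import torch

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return x + x

    ModelUnderTest = MyModel()           # a torch.nn.Module instance
    ModelInputs = (torch.randn(1, 3),)   # tuple of forward() arguments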


+def _load_serialized_model(
+    model_name: str, example_inputs: Any
+) -> Optional[Tuple[torch.nn.Module, Any]]:
+    """Load a serialized Torch model saved via `torch.save`."""
+    if not model_name.endswith((".pth", ".pt")):
+        return None
+
+    logging.info(f"Load model file {model_name}")
+
+    model = torch.load(model_name, weights_only=False)
+    if example_inputs is None:
        raise RuntimeError(
-            f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
+            f"Model '{model_name}' requires input data; specify --model_input <FILE>.pt"
        )
-    logging.debug(f"Loaded model: {model}")
-    logging.debug(f"Loaded input: {example_inputs}")
+
    return model, example_inputs
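
Correspondingly, a sketch of producing files for this path, reusing the
MyModel sketch above (file names assumed). Since a .pt/.pth model file carries
no example inputs, --model_input is mandatory here:

    torch.save(MyModel(), "my_model.pth")                  # the model itself
    torch.save((torch.randn(1, 3),), "my_model_input.pt")  # its example inputs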


+def get_model_and_inputs_from_name(
+    model_name: str, model_input: str | None
+) -> Tuple[torch.nn.Module, Any]:
+    """Resolve a model name into a model instance and example inputs.
+
+    Args:
+        model_name: Identifier for the model. It can be a key in `MODELS` or
+            `MODEL_NAME_TO_MODEL`, a Python module path, or a serialized
+            model file path.
+        model_input: Optional path to a `.pt` file containing example inputs.
+
+    Returns:
+        Tuple of `(model, example_inputs)` ready for compilation.
+
+    Raises:
+        RuntimeError: If the model cannot be resolved or required inputs are
+            missing.
+
+    """
+    example_inputs = _load_example_inputs(model_input)
+
+    loaders = (
+        _load_internal_model,
+        _load_registered_model,
+        _load_python_module_model,
+        _load_serialized_model,
+    )
+
+    for loader in loaders:
+        result = loader(model_name, example_inputs)
+        if result is not None:
+            model, example_inputs = result
+            logging.debug(f"Loaded model: {model}")
+            logging.debug(f"Loaded input: {example_inputs}")
+            return model, example_inputs
+
+    raise RuntimeError(
+        f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
+    )
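
Resolution order in practice, as a sketch (the .py/.pth paths are assumed to
exist as in the sketches above):

    model, inputs = get_model_and_inputs_from_name("qadd", None)
    model, inputs = get_model_and_inputs_from_name("my_model.py", None)
    model, inputs = get_model_and_inputs_from_name("my_model.pth", "my_model_input.pt")

The first loader to return a non-None result wins; an unrecognized name falls
through to the RuntimeError above.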


def quantize(
    model: GraphModule,
    model_name: str,
@@ -150,7 +225,9 @@ def quantize(
    evaluator_config: Dict[str, Any] | None,
) -> GraphModule:
    """This is the official recommended flow for quantization in pytorch 2.0
-    export"""
+    export.
+
+    """
    logging.info("Quantizing Model...")
    logging.debug(f"Original model: {model}")

@@ -238,7 +315,7 @@ def forward(self, x):
    can_delegate = True


-models = {
+MODELS = {
    "qadd": QuantAddTest,
    "qadd2": QuantAddTest2,
    "qops": QuantOpTest,
@@ -247,7 +324,7 @@ def forward(self, x):
    "qlinear": QuantLinearTest,
}

-calibration_data = {
+CALIBRATION_DATA = {
    "qadd": (torch.randn(32, 2, 1),),
    "qadd2": (
        torch.randn(32, 2, 1),
@@ -261,7 +338,7 @@
    ),
}

-targets = [
+TARGETS = [
    "ethos-u55-32",
    "ethos-u55-64",
    "ethos-u55-128",
@@ -289,10 +366,10 @@ def get_calibration_data(
    if evaluator_data is not None:
        return evaluator_data

-    # If the model is in the calibration_data dictionary, get the data from there
+    # If the model is in the CALIBRATION_DATA dictionary, get the data from there
    # This is used for the simple model examples provided
-    if model_name in calibration_data:
-        return calibration_data[model_name]
+    if model_name in CALIBRATION_DATA:
+        return CALIBRATION_DATA[model_name]

    # As a last resort, fall back to the script's previous behavior and return the example inputs
    return example_inputs
@@ -365,7 +442,7 @@ def get_args():
        "-m",
        "--model_name",
        required=True,
-        help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(models.keys())+list(MODEL_NAME_TO_MODEL.keys()))}",
+        help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(MODELS.keys()) + list(MODEL_NAME_TO_MODEL.keys()))}",
    )
    parser.add_argument(
        "--model_input",
@@ -401,8 +478,8 @@ def get_args():
        action="store",
        required=False,
        default="ethos-u55-128",
-        choices=targets,
-        help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}",
+        choices=TARGETS,
+        help=f"For ArmBackend delegated models, pick the target and therefore the instruction set generated. Valid targets are {TARGETS}",
    )
    parser.add_argument(
        "-e",
@@ -506,9 +583,9 @@ def get_args():
        torch.ops.load_library(args.so_library)

    if (
-        args.model_name in models.keys()
+        args.model_name in MODELS.keys()
        and args.delegate is True
-        and models[args.model_name].can_delegate is False
+        and MODELS[args.model_name].can_delegate is False
    ):
        raise RuntimeError(f"Model {args.model_name} cannot be delegated.")
