Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 58 additions & 25 deletions examples/arm/aot_arm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,28 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
save_bundled_program(exec_prog, method_test_suites, output_name)


def quantize_model(
exported_program, args, model: torch.nn.Module, example_inputs, compile_spec
):
model_int8 = quantize(
model,
args.model_name,
compile_spec,
example_inputs,
args.evaluate,
args.evaluate_config,
)
# Wrap quantized model back into an exported_program
exported_program = torch.export.export_for_training(
model_int8, example_inputs, strict=True
)

return model_int8, exported_program


def to_edge_TOSA_delegate(
exported_program,
args,
model: torch.nn.Module,
exported_program, args, model: torch.nn.Module, example_inputs
):
model_int8 = None
# As we can target multiple output encodings, one must
# be specified.
compile_spec = get_compile_spec(
Expand All @@ -634,23 +650,13 @@ def to_edge_TOSA_delegate(
args.system_config,
args.memory_mode,
)

model_int8 = None
if args.quantize:
model = quantize(
model,
args.model_name,
compile_spec,
example_inputs,
args.evaluate,
args.evaluate_config,
model_int8, exported_program = quantize_model(
exported_program, args, model, example_inputs, compile_spec
)
model_int8 = model
# Wrap quantized model back into an exported_program
exported_program = torch.export.export_for_training(
model, example_inputs, strict=True
)

if args.intermediates:
os.makedirs(args.intermediates, exist_ok=True)
model = model_int8

if is_ethosu(compile_spec):
partitioner = EthosUPartitioner(compile_spec)
Expand All @@ -669,6 +675,31 @@ def to_edge_TOSA_delegate(
return model_int8, edge


def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_inputs):
model_int8 = None
if args.quantize:
# As we can target multiple output encodings, one must
# be specified.
compile_spec = get_compile_spec(
args.target,
args.intermediates,
args.system_config,
args.memory_mode,
)
model, exported_program = quantize_model(
exported_program, args, model, example_inputs, compile_spec
)
model_int8 = model

edge = to_edge_transform_and_lower(
exported_program,
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
return model_int8, edge


if __name__ == "__main__": # noqa: C901
args = get_args()

Expand All @@ -686,16 +717,18 @@ def to_edge_TOSA_delegate(
model = exported_program.module()
model_fp32 = model

if args.intermediates:
os.makedirs(args.intermediates, exist_ok=True)

# Quantize if required
model_int8 = None
if args.delegate:
model_int8, edge = to_edge_TOSA_delegate(exported_program, args, model)
model_int8, edge = to_edge_TOSA_delegate(
exported_program, args, model, example_inputs
)
else:
edge = to_edge_transform_and_lower(
exported_program,
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
model_int8, edge = to_edge_no_delegate(
exported_program, args, model, example_inputs
)

dump_delegation_info(edge, args.intermediates)
Expand Down
Loading