Skip to content

Commit fc5aec8

Browse files
Mousius and Ashutosh Parkhi committed
Add support for named outputs in MLF archive
Following from #12789, this adds support for determining the output tensor name from the input model within the MLF metadata json. Co-authored-by: Ashutosh Parkhi <[email protected]>
1 parent 49ed544 commit fc5aec8

File tree

2 files changed

+131
-61
lines changed

2 files changed

+131
-61
lines changed

python/tvm/micro/model_library_format.py

Lines changed: 48 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,28 @@ def _create_type_metadata(input_type):
216216
"dtype": str(input_type.dtype),
217217
}
218218

219+
def _flatten_tuple_outputs(ret_type, predefined_names, offset=0):
    """Flatten a (possibly nested) tuple return type into a name -> type mapping.

    Leaves (TensorTypes) are named from ``predefined_names`` when supplied,
    otherwise as ``output<index>`` where the index counts leaves in order.
    """
    if isinstance(ret_type, tvm.ir.tensor_type.TensorType):
        # Leaf tensor: use the caller-supplied name when one exists,
        # otherwise synthesize one from the running leaf index.
        chosen_name = predefined_names[offset] if predefined_names else f"output{offset}"
        return {chosen_name: ret_type}

    collected = {}
    for field in ret_type.fields:
        # The offset advances by the number of leaves gathered so far,
        # keeping generated names unique across nested tuples.
        collected.update(
            _flatten_tuple_outputs(field, predefined_names, offset + len(collected))
        )
    return collected
234+
def _get_outputs_from_ret_type(ret_type, predefined_names):
    """Return a mapping of output tensor name to output type for a return type.

    A bare TensorType yields a single entry named "output" (or the first
    predefined name); a TupleType is flattened recursively.
    """
    if isinstance(ret_type, tvm.ir.tensor_type.TensorType):
        # Single-output case: exactly one entry.
        single_name = predefined_names[0] if predefined_names else "output"
        return {single_name: ret_type}
    return _flatten_tuple_outputs(ret_type, predefined_names)
219241

220242
def _build_function_memory_map(function_metadata):
221243
"""Build a simple map that shows how much workspace is required to execute
@@ -297,29 +319,21 @@ def _create_empty_entry(target_device_type):
297319
target_main_entries[int(target.get_target_device_type())] = _create_empty_entry(
298320
int(target.get_target_device_type())
299321
)
300-
target_main_entries[int(target.get_target_device_type())]["io_size_bytes"] = int(
322+
target_main_on_device = target_main_entries[int(target.get_target_device_type())]
323+
target_main_on_device["io_size_bytes"] = int(
301324
main_func_metadata.io_sizes[target]
302325
)
303326

304-
# Now, we also add the information about the size of each input and output of the main
305-
# function (in bytes)
306-
input_dict = {}
307-
for input_param in main_func_metadata.relay_primfuncs[target].params:
308-
input_dict[input_param.name_hint] = _create_type_metadata(input_param.checked_type)
309-
target_main_entries[int(target.get_target_device_type())]["inputs"] = input_dict
310-
311-
output_dict = {}
312-
# For output, we dont have the name of the output, so we enumerate them
313-
if isinstance(main_func_metadata.relay_primfuncs[target].ret_type, tvm.ir.type.TupleType):
314-
output_list = _convert_tuple_to_outputs(
315-
main_func_metadata.relay_primfuncs[target].ret_type
316-
)
317-
for i, output_type in enumerate(output_list):
318-
output_dict[f"output{i}"] = _create_type_metadata(output_type)
319-
else:
320-
output_type = main_func_metadata.relay_primfuncs[target].ret_type
321-
output_dict["output"] = _create_type_metadata(output_type)
322-
target_main_entries[int(target.get_target_device_type())]["outputs"] = output_dict
327+
main_relay_func = main_func_metadata.relay_primfuncs[target]
328+
target_main_on_device["inputs"] = {
329+
input_param.name_hint: _create_type_metadata(input_param.checked_type)
330+
for input_param in main_relay_func.params
331+
}
332+
predefined_names = main_relay_func.attrs["output_tensor_names"] if "output_tensor_names" in main_relay_func.attrs else None
333+
target_main_on_device["outputs"] = {
334+
name: _create_type_metadata(output_type)
335+
for name, output_type in _get_outputs_from_ret_type(main_relay_func.ret_type, predefined_names).items()
336+
}
323337

324338
ret = {
325339
"operator_functions": func_entries,
@@ -328,30 +342,6 @@ def _create_empty_entry(target_device_type):
328342
return ret
329343

330344

331-
def _get_main_relay_func(mod: executor_factory.ExecutorFactoryModule):
332-
main_func = mod.function_metadata[MAIN_FUNC_NAME_STR]
333-
target = list(main_func.relay_primfuncs.keys())[0]
334-
return main_func.relay_primfuncs[target]
335-
336-
337-
def _convert_tuple_to_outputs(ret_type, offset=0):
338-
outputs = []
339-
added_fields = len(ret_type.fields)
340-
for output_index in range(added_fields):
341-
next_output = offset + len(outputs)
342-
if isinstance(ret_type.fields[output_index], TupleType):
343-
outputs.extend(_convert_tuple_to_outputs(ret_type.fields[output_index], next_output))
344-
else:
345-
outputs.append(ret_type.fields[output_index])
346-
return outputs
347-
348-
349-
def _get_inputs_and_outputs_from_module(mod):
350-
inputs = [str(input_var.name) for input_var in mod.executor_codegen_metadata.inputs]
351-
outputs = list(mod.executor_codegen_metadata.outputs)
352-
return inputs, outputs
353-
354-
355345
def _get_pools_from_module(mod):
356346
return list(dict(mod.executor_codegen_metadata.pool_inputs).values())
357347

@@ -462,33 +452,30 @@ def _export_graph_model_library_format(
462452
if not include_path.exists():
463453
include_path.mkdir()
464454

465-
inputs, outputs = _get_inputs_and_outputs_from_module(mod)
466455
devices = mod.get_devices()
467456
pools = _get_pools_from_module(mod)
468457
io_pool_allocations = _get_io_pool_allocation_from_module(mod)
458+
main_func = metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0]
469459
workspace_size = int(
470-
metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
460+
main_func[
471461
"workspace_size_bytes"
472462
]
473463
)
474-
inputs_sizes = metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
475-
"inputs"
476-
]
477-
# Here, we merge the output sizes with the actual output names
478-
output_sizes = {}
479-
for i, key in enumerate(
480-
metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
481-
"outputs"
482-
].keys()
483-
):
484-
output_sizes[outputs[i]] = metadata["modules"][mod.libmod_name]["memory"][
485-
"functions"
486-
]["main"][0]["outputs"][key]
464+
inputs = main_func["inputs"]
465+
outputs = main_func["outputs"]
466+
inputs_sizes = {
467+
name: property_map["size"] for name, property_map in inputs.items()
468+
}
469+
output_sizes = {
470+
name: property_map["size"] for name, property_map in outputs.items()
471+
}
472+
input_names = list(inputs.keys())
473+
output_names = list(outputs.keys())
487474

488475
generate_c_interface_header(
489476
mod.libmod_name,
490-
inputs,
491-
outputs,
477+
input_names,
478+
output_names,
492479
pools,
493480
io_pool_allocations,
494481
devices,

tests/python/unittest/test_micro_model_library_format.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,89 @@ def test_multiple_relay_modules_aot_graph():
631631
assert metadata["modules"]["mod2"]["executors"] == ["aot"]
632632
assert metadata["version"] == _GENERATED_VERSION
633633

634+
@tvm.testing.requires_micro
def test_output_name_single():
    """Check a predefined output tensor name appears in the MLF metadata.

    Builds a one-input, one-output Relay module (element-wise add of a
    constant) whose main function carries an ``output_tensor_names`` attr,
    exports it as a Model Library Format archive, and asserts the metadata
    keys the output by the predefined name.
    """
    input_a = tvm.relay.var("input_a", shape=(3, 4, 5), dtype="int64")
    output_1 = input_a + tvm.relay.const(1, "int64")
    # Attach the desired output tensor name to the main function.
    attrs = tvm.ir.make_node("DictAttrs", output_tensor_names=["test_output_a"])
    main_func = tvm.relay.Function([input_a], output_1, attrs=attrs)
    mod = tvm.IRModule.from_expr(main_func)
    mod = tvm.relay.transform.InferType()(mod)

    executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"})
    runtime = Runtime("crt")
    target = tvm.target.target.micro("host")

    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        factory = tvm.relay.build(
            mod, target, runtime=runtime, executor=executor, mod_name="mod1"
        )
    temp_dir = utils.tempdir()
    mlf_tar_path = temp_dir.relpath("lib.tar")

    micro.export_model_library_format(factory, mlf_tar_path)

    tf = tarfile.open(mlf_tar_path)
    extract_dir = temp_dir.relpath("extract")
    os.mkdir(extract_dir)
    tf.extractall(extract_dir)

    with open(os.path.join(extract_dir, "metadata.json")) as f:
        metadata = json.load(f)

    # 3*4*5 int64 elements -> 480 bytes, keyed by the predefined name.
    assert metadata["modules"]["mod1"]["memory"]["functions"]["main"][0]["outputs"] == {
        "test_output_a": {"size": 480, "dtype": "int64"}
    }
668+
669+
670+
@tvm.testing.requires_micro
def test_output_names_many():
    """Check predefined names map onto a nested tuple of outputs in MLF metadata.

    Builds a Relay module returning a nested tuple of four tensors, supplies
    four ``output_tensor_names``, exports an MLF archive, and asserts the
    flattened outputs are keyed by those names in depth-first order.
    """
    input_a = tvm.relay.var("input_a", shape=(3, 4, 5), dtype="int64")
    input_b = tvm.relay.var("input_b", shape=(3, 4), dtype="int32")
    input_c = tvm.relay.var("input_c", shape=(3,), dtype="float32")

    output_1 = input_a + tvm.relay.const(1, "int64")
    output_2 = input_b + tvm.relay.const(2)
    output_3 = input_b + tvm.relay.const(3)
    output_4 = input_c + tvm.relay.const(4.0)

    # Nest the tuple deliberately to exercise recursive flattening.
    full_output = tvm.relay.Tuple(
        [output_1, tvm.relay.Tuple([tvm.relay.Tuple([output_2, output_3]), output_4])]
    )
    attrs = tvm.ir.make_node(
        "DictAttrs",
        output_tensor_names=["test_output_a", "test_output_b", "test_output_c", "test_output_d"],
    )
    main_func = tvm.relay.Function([input_a, input_b, input_c], full_output, attrs=attrs)
    mod = tvm.IRModule.from_expr(main_func)
    mod = tvm.relay.transform.InferType()(mod)

    executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"})
    runtime = Runtime("crt")
    target = tvm.target.target.micro("host")

    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        factory = tvm.relay.build(
            mod, target, runtime=runtime, executor=executor, mod_name="mod1"
        )
    temp_dir = utils.tempdir()
    mlf_tar_path = temp_dir.relpath("lib.tar")

    micro.export_model_library_format(factory, mlf_tar_path)

    tf = tarfile.open(mlf_tar_path)
    extract_dir = temp_dir.relpath("extract")
    os.mkdir(extract_dir)
    tf.extractall(extract_dir)

    with open(os.path.join(extract_dir, "metadata.json")) as f:
        metadata = json.load(f)

    # Depth-first flattening order: output_1, output_2, output_3, output_4.
    assert metadata["modules"]["mod1"]["memory"]["functions"]["main"][0]["outputs"] == {
        "test_output_a": {"size": 480, "dtype": "int64"},
        "test_output_b": {"size": 48, "dtype": "int32"},
        "test_output_c": {"size": 48, "dtype": "int32"},
        "test_output_d": {"size": 12, "dtype": "float32"},
    }
634717

635718
if __name__ == "__main__":
636719
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)