Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 51 additions & 68 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import typing

import tvm
from tvm.ir.type import TupleType
from tvm.micro import get_standalone_crt_dir
from .._ffi import get_global_func
from ..contrib import utils
Expand Down Expand Up @@ -217,6 +216,29 @@ def _create_type_metadata(input_type):
}


def _flatten_tuple_outputs(ret_type, predefined_names, offset=0):
    """Recursively flatten a (possibly nested) tuple return type into a
    name -> tensor-type mapping.

    Each leaf tensor is keyed by ``predefined_names`` at its flattened
    position when names are supplied, otherwise by ``output<index>``.
    """
    # Leaf case: a single tensor contributes one entry at this offset.
    if isinstance(ret_type, tvm.ir.tensor_type.TensorType):
        if predefined_names:
            leaf_name = predefined_names[offset]
        else:
            leaf_name = f"output{offset}"
        return {leaf_name: ret_type}

    # Tuple case: recurse into each field, advancing the offset by the
    # number of leaves collected so far so nested tuples number correctly.
    flattened = {}
    for field in ret_type.fields:
        flattened.update(_flatten_tuple_outputs(field, predefined_names, offset + len(flattened)))

    return flattened


def _get_outputs_from_ret_type(ret_type, predefined_names):
    """Map output names to tensor types for a main-function return type.

    A bare TensorType yields a single entry (named ``output`` when no
    predefined names are given); tuple types are flattened recursively.
    """
    if not isinstance(ret_type, tvm.ir.tensor_type.TensorType):
        return _flatten_tuple_outputs(ret_type, predefined_names)
    single_name = predefined_names[0] if predefined_names else "output"
    return {single_name: ret_type}


def _build_function_memory_map(function_metadata):
"""Build a simple map that shows how much workspace is required to execute
each primitive function. The main_func describes how much memory is required
Expand Down Expand Up @@ -297,29 +319,25 @@ def _create_empty_entry(target_device_type):
target_main_entries[int(target.get_target_device_type())] = _create_empty_entry(
int(target.get_target_device_type())
)
target_main_entries[int(target.get_target_device_type())]["io_size_bytes"] = int(
main_func_metadata.io_sizes[target]
)
target_main_on_device = target_main_entries[int(target.get_target_device_type())]
target_main_on_device["io_size_bytes"] = int(main_func_metadata.io_sizes[target])

# Now, we also add the information about the size of each input and output of the main
# function (in bytes)
input_dict = {}
for input_param in main_func_metadata.relay_primfuncs[target].params:
input_dict[input_param.name_hint] = _create_type_metadata(input_param.checked_type)
target_main_entries[int(target.get_target_device_type())]["inputs"] = input_dict

output_dict = {}
# For output, we dont have the name of the output, so we enumerate them
if isinstance(main_func_metadata.relay_primfuncs[target].ret_type, tvm.ir.type.TupleType):
output_list = _convert_tuple_to_outputs(
main_func_metadata.relay_primfuncs[target].ret_type
)
for i, output_type in enumerate(output_list):
output_dict[f"output{i}"] = _create_type_metadata(output_type)
else:
output_type = main_func_metadata.relay_primfuncs[target].ret_type
output_dict["output"] = _create_type_metadata(output_type)
target_main_entries[int(target.get_target_device_type())]["outputs"] = output_dict
main_relay_func = main_func_metadata.relay_primfuncs[target]
target_main_on_device["inputs"] = {
input_param.name_hint: _create_type_metadata(input_param.checked_type)
for input_param in main_relay_func.params
}
predefined_names = (
main_relay_func.attrs["output_tensor_names"]
if "output_tensor_names" in main_relay_func.attrs
else None
)
target_main_on_device["outputs"] = {
name: _create_type_metadata(output_type)
for name, output_type in _get_outputs_from_ret_type(
main_relay_func.ret_type, predefined_names
).items()
}

ret = {
"operator_functions": func_entries,
Expand All @@ -328,30 +346,6 @@ def _create_empty_entry(target_device_type):
return ret


def _get_main_relay_func(mod: executor_factory.ExecutorFactoryModule):
    """Return the Relay main function recorded for the first target."""
    primfuncs = mod.function_metadata[MAIN_FUNC_NAME_STR].relay_primfuncs
    first_target = next(iter(primfuncs.keys()))
    return primfuncs[first_target]


def _convert_tuple_to_outputs(ret_type, offset=0):
    """Flatten a (possibly nested) TupleType into a flat list of field types."""
    flat = []
    for field in ret_type.fields:
        if isinstance(field, TupleType):
            # Nested tuple: splice its flattened fields in order.
            flat.extend(_convert_tuple_to_outputs(field, offset + len(flat)))
        else:
            flat.append(field)
    return flat


def _get_inputs_and_outputs_from_module(mod):
inputs = [str(input_var.name) for input_var in mod.executor_codegen_metadata.inputs]
outputs = list(mod.executor_codegen_metadata.outputs)
return inputs, outputs


def _get_pools_from_module(mod):
return list(dict(mod.executor_codegen_metadata.pool_inputs).values())

Expand Down Expand Up @@ -462,33 +456,22 @@ def _export_graph_model_library_format(
if not include_path.exists():
include_path.mkdir()

inputs, outputs = _get_inputs_and_outputs_from_module(mod)
devices = mod.get_devices()
pools = _get_pools_from_module(mod)
io_pool_allocations = _get_io_pool_allocation_from_module(mod)
workspace_size = int(
metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
"workspace_size_bytes"
]
)
inputs_sizes = metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
"inputs"
]
# Here, we merge the output sizes with the actual output names
output_sizes = {}
for i, key in enumerate(
metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
"outputs"
].keys()
):
output_sizes[outputs[i]] = metadata["modules"][mod.libmod_name]["memory"][
"functions"
]["main"][0]["outputs"][key]
main_func = metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0]
workspace_size = int(main_func["workspace_size_bytes"])
inputs = main_func["inputs"]
outputs = main_func["outputs"]
inputs_sizes = {name: property_map["size"] for name, property_map in inputs.items()}
output_sizes = {name: property_map["size"] for name, property_map in outputs.items()}
input_names = list(inputs.keys())
output_names = list(outputs.keys())

generate_c_interface_header(
mod.libmod_name,
inputs,
outputs,
input_names,
output_names,
pools,
io_pool_allocations,
devices,
Expand Down
84 changes: 84 additions & 0 deletions tests/python/unittest/test_micro_model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,5 +632,89 @@ def test_multiple_relay_modules_aot_graph():
assert metadata["version"] == _GENERATED_VERSION


@tvm.testing.requires_micro
def test_output_name_single():
    """Check a named single output appears in Model Library Format metadata.

    Builds a trivial elementwise-add Relay module whose main function
    carries an ``output_tensor_names`` attribute, exports it in MLF, and
    asserts the metadata records the given output name and size.
    """
    input_a = tvm.relay.var("input_a", shape=(3, 4, 5), dtype="int64")
    output_1 = input_a + tvm.relay.const(1, "int64")
    # Attach the desired output name so the MLF export uses it instead of
    # the generated "output" default.
    attrs = tvm.ir.make_node("DictAttrs", output_tensor_names=["test_output_a"])
    main_func = tvm.relay.Function([input_a], output_1, attrs=attrs)
    mod = tvm.IRModule.from_expr(main_func)
    mod = tvm.relay.transform.InferType()(mod)

    executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"})
    runtime = Runtime("crt")
    target = tvm.target.target.micro("host")

    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        factory = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod1")
    temp_dir = utils.tempdir()
    mlf_tar_path = temp_dir.relpath("lib.tar")

    micro.export_model_library_format(factory, mlf_tar_path)

    extract_dir = temp_dir.relpath("extract")
    os.mkdir(extract_dir)
    # Use a context manager so the archive handle is closed promptly
    # rather than leaked until interpreter shutdown.
    with tarfile.open(mlf_tar_path) as tf:
        tf.extractall(extract_dir)

    with open(os.path.join(extract_dir, "metadata.json")) as f:
        metadata = json.load(f)

    # 3 * 4 * 5 int64 elements -> 480 bytes.
    assert metadata["modules"]["mod1"]["memory"]["functions"]["main"][0]["outputs"] == {
        "test_output_a": {"size": 480, "dtype": "int64"}
    }


@tvm.testing.requires_micro
def test_output_names_many():
    """Check nested tuple outputs are flattened with their given names.

    Builds a Relay module returning a nested tuple of four tensors, names
    them via ``output_tensor_names``, exports in Model Library Format, and
    asserts the metadata lists each flattened output with its name, size,
    and dtype.
    """
    input_a = tvm.relay.var("input_a", shape=(3, 4, 5), dtype="int64")
    input_b = tvm.relay.var("input_b", shape=(3, 4), dtype="int32")
    input_c = tvm.relay.var("input_c", shape=(3,), dtype="float32")

    output_1 = input_a + tvm.relay.const(1, "int64")
    output_2 = input_b + tvm.relay.const(2)
    output_3 = input_b + tvm.relay.const(3)
    output_4 = input_c + tvm.relay.const(4.0)

    # Deliberately nest the tuple to exercise recursive flattening of
    # the output names in the exporter.
    full_output = tvm.relay.Tuple(
        [output_1, tvm.relay.Tuple([tvm.relay.Tuple([output_2, output_3]), output_4])]
    )
    attrs = tvm.ir.make_node(
        "DictAttrs",
        output_tensor_names=["test_output_a", "test_output_b", "test_output_c", "test_output_d"],
    )
    main_func = tvm.relay.Function([input_a, input_b, input_c], full_output, attrs=attrs)
    mod = tvm.IRModule.from_expr(main_func)
    mod = tvm.relay.transform.InferType()(mod)

    executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"})
    runtime = Runtime("crt")
    target = tvm.target.target.micro("host")

    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        factory = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod1")
    temp_dir = utils.tempdir()
    mlf_tar_path = temp_dir.relpath("lib.tar")

    micro.export_model_library_format(factory, mlf_tar_path)

    extract_dir = temp_dir.relpath("extract")
    os.mkdir(extract_dir)
    # Use a context manager so the archive handle is closed promptly
    # rather than leaked until interpreter shutdown.
    with tarfile.open(mlf_tar_path) as tf:
        tf.extractall(extract_dir)

    with open(os.path.join(extract_dir, "metadata.json")) as f:
        metadata = json.load(f)

    # Sizes: (3*4*5)*8, (3*4)*4, (3*4)*4, 3*4 bytes respectively.
    assert metadata["modules"]["mod1"]["memory"]["functions"]["main"][0]["outputs"] == {
        "test_output_a": {"size": 480, "dtype": "int64"},
        "test_output_b": {"size": 48, "dtype": "int32"},
        "test_output_c": {"size": 48, "dtype": "int32"},
        "test_output_d": {"size": 12, "dtype": "float32"},
    }


if __name__ == "__main__":
    # Forward any extra CLI arguments to pytest and propagate its exit status.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update this to tvm.testing.main()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like it'd be better done as a blanket change rather than polluting many individual patches?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either way is fine. We don't have a lint step to catch these, which is why I've been trying to notify people to fix it as I see them in PRs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not introduce unrelated changes into my patch, I can try a mass find and replace in a future patch? 😸