-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add "operator" style to Model Library Format #8072
Changes from all commits
51e841c
76ef467
66126f4
010f8ff
d0aa180
408d154
2537d3a
c04e8f3
9fff102
b007dfa
5c20bd6
7b1ef1a
493b13d
32abcf2
be680af
ed3008d
c200ef5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,12 +20,18 @@ | |
import datetime | ||
import json | ||
import os | ||
import pathlib | ||
import re | ||
import tarfile | ||
import typing | ||
|
||
from .._ffi import get_global_func | ||
from ..contrib import utils | ||
from ..driver import build_module | ||
from ..runtime import ndarray as _nd | ||
from ..relay.backend import executor_factory | ||
from ..relay import param_dict | ||
from ..tir import expr | ||
|
||
# This should be kept identical to runtime::symbol::tvm_module_main | ||
MAIN_FUNC_NAME_STR = "__tvm_main__" | ||
|
@@ -207,67 +213,207 @@ def _build_function_memory_map(function_metadata): | |
return ret | ||
|
||
|
||
def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name): | ||
"""Export the build artifact in Model Library Format. | ||
def _make_tar(source_dir, tar_file_path): | ||
"""Build a tar file from source_dir.""" | ||
with tarfile.open(tar_file_path, "w") as tar_f: | ||
|
||
This function creates a .tar archive containing the build artifacts in a standardized | ||
layout. It's intended to allow downstream automation to build TVM artifacts against the C | ||
runtime. | ||
def reset(tarinfo): | ||
tarinfo.uid = tarinfo.gid = 0 | ||
tarinfo.uname = tarinfo.gname = "root" | ||
return tarinfo | ||
|
||
tar_f.add(str(source_dir), arcname=".", filter=reset) | ||
|
||
|
||
_GENERATED_VERSION = 4 | ||
|
||
|
||
def _export_graph_model_library_format( | ||
mod: executor_factory.ExecutorFactoryModule, tempdir: pathlib.Path | ||
): | ||
"""Export a tvm.relay.build artifact in Model Library Format. | ||
|
||
Parameters | ||
---------- | ||
mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule | ||
The return value of tvm.relay.build, which will be exported into Model Library Format. | ||
file_name : str | ||
Path to the .tar archive to generate. | ||
|
||
Returns | ||
------- | ||
file_name : str | ||
The path to the generated .tar archive. | ||
tempdir : pathlib.Path | ||
Temporary directory to populate with Model Library Format contents. | ||
""" | ||
tempdir = utils.tempdir() | ||
is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) | ||
runtime = ["aot"] if is_aot else ["graph"] | ||
|
||
metadata = { | ||
"version": 3, | ||
"version": _GENERATED_VERSION, | ||
"model_name": mod.libmod_name, | ||
"export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), | ||
"memory": _build_memory_map(mod), | ||
"target": {int(k): str(v) for k, v in mod.target.items()}, | ||
"runtimes": runtime, | ||
"style": "full-model", | ||
} | ||
|
||
with open(tempdir.relpath("metadata.json"), "w") as json_f: | ||
with open(tempdir / "metadata.json", "w") as json_f: | ||
json.dump(metadata, json_f, indent=2, sort_keys=True) | ||
|
||
codegen_dir_path = tempdir.relpath("codegen") | ||
os.mkdir(codegen_dir_path) | ||
_populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name) | ||
codegen_dir = tempdir / "codegen" | ||
codegen_dir.mkdir() | ||
_populate_codegen_dir(mod.lib, codegen_dir, mod.libmod_name) | ||
|
||
parameters_dir_path = tempdir.relpath("parameters") | ||
os.mkdir(parameters_dir_path) | ||
param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params") | ||
parameters_dir = tempdir / "parameters" | ||
parameters_dir.mkdir() | ||
param_filename = parameters_dir / f"{mod.libmod_name}.params" | ||
with open(param_filename, "wb") as f: | ||
f.write(param_dict.save_param_dict(mod.params)) | ||
|
||
with open(tempdir.relpath("relay.txt"), "w") as f: | ||
src_dir = tempdir / "src" | ||
src_dir.mkdir() | ||
with open(src_dir / "relay.txt", "w") as f: | ||
f.write(str(mod.ir_mod)) | ||
|
||
if not is_aot: | ||
graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) | ||
os.makedirs(graph_config_dir_path) | ||
with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: | ||
graph_config_dir = tempdir / "runtime-config" / "graph" | ||
graph_config_dir.mkdir(parents=True) | ||
with open(graph_config_dir / "graph.json", "w") as f: | ||
f.write(mod.get_executor_config()) | ||
|
||
with tarfile.open(file_name, "w") as tar_f: | ||
|
||
def reset(tarinfo): | ||
tarinfo.uid = tarinfo.gid = 0 | ||
tarinfo.uname = tarinfo.gname = "root" | ||
return tarinfo | ||
class NonStaticShapeError(Exception):
    """Signals that a buffer shape contains an element which is not an IntImm."""
|
||
|
||
def _shape_to_size(shape, dtype): | ||
bits_per_item = int( | ||
re.match(r"((float)|(int))(?P<width_bits>[0-9]+)", dtype).group("width_bits") | ||
) | ||
assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" | ||
total_bits = bits_per_item | ||
for s in shape: | ||
total_bits *= s | ||
|
||
return (total_bits + 7) // 8 | ||
|
||
|
||
def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_target): | ||
def _eval_shape(param_name, buffer_shape): | ||
shape = [] | ||
for x in buffer_shape: | ||
if not isinstance(x, expr.IntImm): | ||
raise NonStaticShapeError( | ||
f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}" | ||
) | ||
shape.append(x.value) | ||
return shape | ||
|
||
memory_map = {} | ||
for target_device_type, target in targets.items(): | ||
ir_mod = ir_module_by_target[target] | ||
printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False) | ||
with open(src_dir / f"tir-{target_device_type}.txt", "w") as f: | ||
f.write(printer["print"](ir_mod)) | ||
Comment on lines
+312
to
+313
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't follow why the TIR is being added to the archive. Is this for test purposes? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's mostly by analogy with adding relay.txt to the archive: it provides the TVM source code for the generated code. That said, I see your point that TIR is quite close to the generated code. |
||
|
||
for v in ir_mod.get_global_vars(): | ||
map_entry = [] | ||
for p, b in ir_mod[v.name_hint].buffer_map.items(): | ||
shape = _eval_shape(p.name, b.shape) | ||
buffer_size_bytes = _shape_to_size(shape, str(b.dtype)) | ||
# NOTE: cannot tell what is an input or output at this point. | ||
map_entry.append( | ||
{ | ||
"size_bytes": buffer_size_bytes, | ||
"shape": [int(x) for x in b.shape], | ||
"dtype": b.dtype, | ||
"input_binding": printer["get_var_name"](p), | ||
} | ||
) | ||
memory_map[v.name_hint] = map_entry | ||
|
||
return memory_map | ||
|
||
|
||
def _export_operator_model_library_format(
    mod: build_module.OperatorModule, tempdir: pathlib.Path
):
    """Export the result of tvm.build() in Model Library Format.

    Writes src/ (TIR text), metadata.json and codegen/ underneath tempdir.

    Parameters
    ----------
    mod : build_module.OperatorModule
        The Module returned from tvm.build().
    tempdir : pathlib.Path
        Temporary directory to populate with Model Library Format contents.

    Raises
    ------
    UnsupportedInModelLibraryFormatError
        If any target in mod.ir_module_by_target has a kind other than "llvm" or "c".
    """
    # Maps integer device type -> target, for every target the operator was built for.
    targets = {}
    for target in mod.ir_module_by_target.keys():
        if str(target.kind) not in ("llvm", "c"):
            raise UnsupportedInModelLibraryFormatError(
                f"Operator has non-DSO-exportable target {target!s}, which is not yet supported in "
                "Model Library Format"
            )

        targets[int(_nd.device(str(target)).device_type)] = target

    src_dir = tempdir / "src"
    src_dir.mkdir()
    memory_map = _write_tir_and_build_operator_memory_map(src_dir, targets, mod.ir_module_by_target)

    metadata = {
        "version": _GENERATED_VERSION,
        "model_name": mod.name,
        "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
        "memory": memory_map,
        "target": {k: str(v) for k, v in targets.items()},
        "runtimes": [],
        "style": "operator",
    }
    with open(tempdir / "metadata.json", "w") as metadata_f:
        # Same formatting as the graph/AOT exporter so metadata.json looks alike for both styles.
        json.dump(metadata, metadata_f, indent=2, sort_keys=True)

    codegen_dir = tempdir / "codegen"
    codegen_dir.mkdir()
    _populate_codegen_dir(mod, codegen_dir)
|
||
|
||
ExportableModule = typing.Union[ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure whether model_library_format.py is the right place to hold this. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess this is a bit specific to Model Library Format: you can build shared libraries from things we don't know how to export into MLF. Happy to change the name, or we can revisit this when we promote MLF to a top-level TVM export format. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ack, sounds good for now then. |
||
build_module.OperatorModule, | ||
executor_factory.AOTExecutorFactoryModule, | ||
executor_factory.GraphExecutorFactoryModule, | ||
] | ||
|
||
|
||
def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]): | ||
"""Export the build artifact in Model Library Format. | ||
|
||
This function creates a .tar archive containing the build artifacts in a standardized | ||
layout. It's intended to allow downstream automation to build TVM artifacts against the C | ||
runtime. | ||
|
||
Parameters | ||
---------- | ||
mod : ExportableModule | ||
The return value of tvm.build or tvm.relay.build. | ||
file_name : str | ||
Path to the .tar archive to generate. | ||
|
||
Returns | ||
------- | ||
file_name : str | ||
The path to the generated .tar archive. | ||
""" | ||
file_name = pathlib.Path(file_name) | ||
|
||
tempdir = utils.tempdir() | ||
|
||
tar_f.add(tempdir.temp_dir, arcname=".", filter=reset) | ||
if isinstance(mod, build_module.OperatorModule): | ||
_export_operator_model_library_format(mod, tempdir.path) | ||
elif isinstance( | ||
mod, | ||
(executor_factory.AOTExecutorFactoryModule, executor_factory.GraphExecutorFactoryModule), | ||
): | ||
_export_graph_model_library_format(mod, tempdir.path) | ||
else: | ||
raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}") | ||
|
||
_make_tar(tempdir.path, file_name) | ||
|
||
return file_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel we can have this neatly hidden under _ffi_api.py and move the c++ implementations related to ModelLibraryFormatPrinter to a matching model_library_format.cc.
Why do we think ModelLibraryFormatPrinter belongs to the namespace of tir?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, good point. In `src/printer`, we have a few entry points: `src/printer/tvmscript_printer.cc` defines `script.AsTVMScript`, and `src/printer/text_printer.cc` defines `ir.PrettyPrint` and `ir.AsText`. So I guess the folder doesn't provide any namespace grouping right now, even though printer implementations are consolidated there. I'm okay moving to `micro.ModelLibraryFormatPrinter` or `ir.ModelLibraryFormatPrinter`, if that's what you're suggesting. `tir` seemed like a fit since that's how we are using it now, though it should work with any IRModule. Could you let me know which namespace you're suggesting to move to?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking of "micro.model_library_format.printer" being the registration, and making printer a Python function that binds to C++ under _ffi_api.py (similar to how it's done in CallGraph).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, we need the member function to retrieve the mapping; this is why I used Module. As for the namespace, I don't have a strong opinion, but the only `micro` directory we have in `src` is under `src/runtime`, and this is clearly not a runtime component. So we'd need to create `src/micro`, is all. I'm not opposed to that, but was following the convention for `Printer` in keeping `ModelLibraryFormatPrinter` underneath `src/printer`.