Add "operator" style to Model Library Format (#8072)
* rename _update_target and document its function

* make tvm.build return OperatorModule to return multiple outputs

* allow retrieving the var names used in TIR repr

* add Operator Model Library Format and test

* Add pathlib convenience functions to utils.TempDirectory.

* fix tests

* black format

* git-clang-format

* pylint fixes

* add asf header

* change memory map to make more sense, fix tests

* address giuseros comments

* align GetVarName with future TypedPackedFunc

* fix test

* clang-format

* rev model library format to v4 (bad merge)
areusch authored Jul 2, 2021
1 parent 354d996 commit 970aeff
Showing 8 changed files with 399 additions and 46 deletions.
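To situate the diffs below: after this change, tvm.build() returns an OperatorModule, and export_model_library_format() accepts it alongside the Relay executor factory modules. A minimal sketch of the new workflow (the add_one operator here is illustrative, not taken from this commit's tests):

import tvm
from tvm import te
from tvm.micro import export_model_library_format

# Build a single operator rather than a full Relay model.
A = te.placeholder((2, 3), dtype="float32", name="A")
B = te.compute(A.shape, lambda i, j: A[i, j] + 1.0, name="B")
sched = te.create_schedule(B.op)

# tvm.build() now returns an OperatorModule (see python/tvm/driver/build_module.py below).
opmod = tvm.build(sched, [A, B], target="llvm", name="add_one")

# Export a v4 Model Library Format archive with metadata "style": "operator".
export_model_library_format(opmod, "./add_one.tar")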
13 changes: 13 additions & 0 deletions python/tvm/contrib/utils.py
@@ -19,6 +19,7 @@
 import contextlib
 import datetime
 import os
+import pathlib
 import tempfile
 import threading
 import shutil
@@ -119,6 +120,18 @@ def remove(self):
             self.TEMPDIRS.remove(self.temp_dir)
         self.temp_dir = None

+    @property
+    def path(self):
+        return pathlib.Path(self.temp_dir)
+
+    def __truediv__(self, other):
+        if not isinstance(other, (str, pathlib.Path)):
+            raise TypeError(
+                "TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,)
+            )
+
+        return self.path / other
+
     def __del__(self):
         temp_dirs = getattr(self, "TEMPDIRS", None)
         if temp_dirs is None:
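These helpers make a TempDirectory composable with pathlib. A small usage sketch of the additions above:

from tvm.contrib import utils

td = utils.tempdir()
p = td / "metadata.json"          # __truediv__ returns a pathlib.Path
assert p == td.path / "metadata.json"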
29 changes: 25 additions & 4 deletions python/tvm/driver/build_module.py
@@ -24,6 +24,7 @@

 import tvm.tir

+from tvm.runtime import Module
 from tvm.runtime import ndarray
 from tvm.ir import container
 from tvm.ir import CallingConv
@@ -372,12 +373,32 @@ def build(
         create_csource_crt_metadata_module = tvm._ffi.get_global_func(
             "runtime.CreateCSourceCrtMetadataModule"
         )
-        return create_csource_crt_metadata_module([rt_mod_host], target_host)
+        to_return = create_csource_crt_metadata_module([rt_mod_host], target_host)

-    if target_host.kind.name == "llvm":
+    elif target_host.kind.name == "llvm":
         create_llvm_crt_metadata_module = tvm._ffi.get_global_func(
             "runtime.CreateLLVMCrtMetadataModule"
         )
-        return create_llvm_crt_metadata_module([rt_mod_host], target_host)
+        to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host)
+    else:
+        to_return = rt_mod_host

-    return rt_mod_host
+    return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name)
+
+
+class OperatorModule(Module):
+    """Wraps the Module returned by tvm.build() and captures additional outputs of that function."""
+
+    @classmethod
+    def from_module(cls, mod, **kwargs):
+        # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward.
+        # If an exception occurs in cls.__init__, handle will be deleted. For this reason,
+        # set mod.handle to None.
+        handle = mod.handle
+        mod.handle = None
+        return cls(handle, **kwargs)
+
+    def __init__(self, handle, ir_module_by_target=None, name=None):
+        super(OperatorModule, self).__init__(handle)
+        self.ir_module_by_target = ir_module_by_target
+        self.name = name
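Because OperatorModule proxies the underlying runtime module while capturing extra build outputs, callers can both run the compiled code and inspect the TIR that produced it. A hedged sketch, continuing the add_one example above:

opmod = tvm.build(sched, [A, B], target="llvm", name="add_one")
print(opmod.name)  # "add_one"
for target, ir_mod in opmod.ir_module_by_target.items():
    # One tvm.IRModule of TIR per Target, captured at build time.
    print(target, ir_mod)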
208 changes: 177 additions & 31 deletions python/tvm/micro/model_library_format.py
@@ -20,12 +20,18 @@
 import datetime
 import json
 import os
+import pathlib
+import re
 import tarfile
+import typing

+from .._ffi import get_global_func
 from ..contrib import utils
+from ..driver import build_module
+from ..runtime import ndarray as _nd
 from ..relay.backend import executor_factory
 from ..relay import param_dict
+from ..tir import expr

 # This should be kept identical to runtime::symbol::tvm_module_main
 MAIN_FUNC_NAME_STR = "__tvm_main__"
@@ -207,67 +213,207 @@ def _build_function_memory_map(function_metadata):
     return ret


-def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name):
-    """Export the build artifact in Model Library Format.
+def _make_tar(source_dir, tar_file_path):
+    """Build a tar file from source_dir."""
+    with tarfile.open(tar_file_path, "w") as tar_f:

-    This function creates a .tar archive containing the build artifacts in a standardized
-    layout. It's intended to allow downstream automation to build TVM artifacts against the C
-    runtime.
+        def reset(tarinfo):
+            tarinfo.uid = tarinfo.gid = 0
+            tarinfo.uname = tarinfo.gname = "root"
+            return tarinfo

+        tar_f.add(str(source_dir), arcname=".", filter=reset)
+
+
+_GENERATED_VERSION = 4
+
+
+def _export_graph_model_library_format(
+    mod: executor_factory.ExecutorFactoryModule, tempdir: pathlib.Path
+):
+    """Export a tvm.relay.build artifact in Model Library Format.
+
     Parameters
     ----------
     mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule
         The return value of tvm.relay.build, which will be exported into Model Library Format.
-    file_name : str
-        Path to the .tar archive to generate.
-    Returns
-    -------
-    file_name : str
-        The path to the generated .tar archive.
+    tempdir : pathlib.Path
+        Temporary directory to populate with Model Library Format contents.
     """
-    tempdir = utils.tempdir()
     is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule)
     runtime = ["aot"] if is_aot else ["graph"]

     metadata = {
-        "version": 3,
+        "version": _GENERATED_VERSION,
         "model_name": mod.libmod_name,
         "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
         "memory": _build_memory_map(mod),
         "target": {int(k): str(v) for k, v in mod.target.items()},
         "runtimes": runtime,
+        "style": "full-model",
     }

-    with open(tempdir.relpath("metadata.json"), "w") as json_f:
+    with open(tempdir / "metadata.json", "w") as json_f:
         json.dump(metadata, json_f, indent=2, sort_keys=True)

-    codegen_dir_path = tempdir.relpath("codegen")
-    os.mkdir(codegen_dir_path)
-    _populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name)
+    codegen_dir = tempdir / "codegen"
+    codegen_dir.mkdir()
+    _populate_codegen_dir(mod.lib, codegen_dir, mod.libmod_name)

-    parameters_dir_path = tempdir.relpath("parameters")
-    os.mkdir(parameters_dir_path)
-    param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params")
+    parameters_dir = tempdir / "parameters"
+    parameters_dir.mkdir()
+    param_filename = parameters_dir / f"{mod.libmod_name}.params"
     with open(param_filename, "wb") as f:
         f.write(param_dict.save_param_dict(mod.params))

-    with open(tempdir.relpath("relay.txt"), "w") as f:
+    src_dir = tempdir / "src"
+    src_dir.mkdir()
+    with open(src_dir / "relay.txt", "w") as f:
         f.write(str(mod.ir_mod))

     if not is_aot:
-        graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph"))
-        os.makedirs(graph_config_dir_path)
-        with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f:
+        graph_config_dir = tempdir / "runtime-config" / "graph"
+        graph_config_dir.mkdir(parents=True)
+        with open(graph_config_dir / "graph.json", "w") as f:
             f.write(mod.get_executor_config())

-    with tarfile.open(file_name, "w") as tar_f:
-
-        def reset(tarinfo):
-            tarinfo.uid = tarinfo.gid = 0
-            tarinfo.uname = tarinfo.gname = "root"
-            return tarinfo
+
+class NonStaticShapeError(Exception):
+    """Raised when a shape has elements other than IntImm."""


+def _shape_to_size(shape, dtype):
+    bits_per_item = int(
+        re.match(r"((float)|(int))(?P<width_bits>[0-9]+)", dtype).group("width_bits")
+    )
+    assert bits_per_item is not None, f"don't know how to compute size of type {dtype}"
+    total_bits = bits_per_item
+    for s in shape:
+        total_bits *= s
+
+    return (total_bits + 7) // 8
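A worked example of the arithmetic: a float32 buffer of shape (2, 3) has 32 bits per item and 6 items. (Note that for dtypes not starting with "float" or "int", e.g. "uint8", re.match returns None and raises before the assert can fire.)

_shape_to_size([2, 3], "float32")  # (32 * 2 * 3 + 7) // 8 == 24 bytes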


+def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_target):
+    def _eval_shape(param_name, buffer_shape):
+        shape = []
+        for x in buffer_shape:
+            if not isinstance(x, expr.IntImm):
+                raise NonStaticShapeError(
+                    f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}"
+                )
+            shape.append(x.value)
+        return shape
+
+    memory_map = {}
+    for target_device_type, target in targets.items():
+        ir_mod = ir_module_by_target[target]
+        printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False)
+        with open(src_dir / f"tir-{target_device_type}.txt", "w") as f:
+            f.write(printer["print"](ir_mod))
+
+        for v in ir_mod.get_global_vars():
+            map_entry = []
+            for p, b in ir_mod[v.name_hint].buffer_map.items():
+                shape = _eval_shape(p.name, b.shape)
+                buffer_size_bytes = _shape_to_size(shape, str(b.dtype))
+                # NOTE: cannot tell what is an input or output at this point.
+                map_entry.append(
+                    {
+                        "size_bytes": buffer_size_bytes,
+                        "shape": [int(x) for x in b.shape],
+                        "dtype": b.dtype,
+                        "input_binding": printer["get_var_name"](p),
+                    }
+                )
+            memory_map[v.name_hint] = map_entry
+
+    return memory_map
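Each entry describes one buffer in a PrimFunc's buffer_map, keyed by the function's global var name. For the hypothetical add_one operator sketched earlier, the resulting "memory" metadata would look roughly like this (illustrative values; the actual input_binding strings come from printer["get_var_name"]):

{
    "add_one": [
        {"size_bytes": 24, "shape": [2, 3], "dtype": "float32", "input_binding": "A"},
        {"size_bytes": 24, "shape": [2, 3], "dtype": "float32", "input_binding": "B"},
    ]
}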


+def _export_operator_model_library_format(mod: build_module.OperatorModule, tempdir):
+    """Export the result of tvm.build() in Model Library Format.
+
+    Parameters
+    ----------
+    mod : build_module.OperatorModule
+        The Module returned from tvm.build().
+    tempdir : pathlib.Path
+        Temporary directory to populate with Model Library Format contents.
+    """
+    targets = {}
+    for target in mod.ir_module_by_target.keys():
+        if str(target.kind) not in ("llvm", "c"):
+            raise UnsupportedInModelLibraryFormatError(
+                f"Operator has non-DSO-exportable target {target!s}, which is not yet supported in "
+                "Model Library Format"
+            )
+
+        targets[int(_nd.device(str(target)).device_type)] = target
+
+    src_dir = tempdir / "src"
+    src_dir.mkdir()
+    memory_map = _write_tir_and_build_operator_memory_map(src_dir, targets, mod.ir_module_by_target)
+
+    metadata = {
+        "version": _GENERATED_VERSION,
+        "model_name": mod.name,
+        "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
+        "memory": memory_map,
+        "target": {k: str(v) for k, v in targets.items()},
+        "runtimes": [],
+        "style": "operator",
+    }
+    with open(tempdir / "metadata.json", "w") as metadata_f:
+        json.dump(metadata, metadata_f)
+
+    codegen_dir = tempdir / "codegen"
+    codegen_dir.mkdir()
+    _populate_codegen_dir(mod, codegen_dir)
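After this function runs, the temporary directory holds the operator-style layout, roughly as follows (assuming a single CPU host target, for which device_type 1 = kDLCPU):

metadata.json    # version 4, "style": "operator", "memory" map as above
codegen/         # DSO-exportable sources or objects from _populate_codegen_dir
src/tir-1.txt    # TIR text emitted by tir.ModelLibraryFormatPrinter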


+ExportableModule = typing.Union[
+    build_module.OperatorModule,
+    executor_factory.AOTExecutorFactoryModule,
+    executor_factory.GraphExecutorFactoryModule,
+]


+def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]):
+    """Export the build artifact in Model Library Format.
+
+    This function creates a .tar archive containing the build artifacts in a standardized
+    layout. It's intended to allow downstream automation to build TVM artifacts against the C
+    runtime.
+
+    Parameters
+    ----------
+    mod : ExportableModule
+        The return value of tvm.build or tvm.relay.build.
+    file_name : str
+        Path to the .tar archive to generate.
+
+    Returns
+    -------
+    file_name : str
+        The path to the generated .tar archive.
+    """
+    file_name = pathlib.Path(file_name)
+
+    tempdir = utils.tempdir()

-        tar_f.add(tempdir.temp_dir, arcname=".", filter=reset)
+    if isinstance(mod, build_module.OperatorModule):
+        _export_operator_model_library_format(mod, tempdir.path)
+    elif isinstance(
+        mod,
+        (executor_factory.AOTExecutorFactoryModule, executor_factory.GraphExecutorFactoryModule),
+    ):
+        _export_graph_model_library_format(mod, tempdir.path)
+    else:
+        raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}")
+
+    _make_tar(tempdir.path, file_name)
+
+    return file_name
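A consumer can then unpack the archive and read the metadata back; a short sketch using only standard-library calls:

import json
import tarfile

with tarfile.open("./add_one.tar") as tar_f:
    tar_f.extractall("./extracted")

with open("./extracted/metadata.json") as json_f:
    metadata = json.load(json_f)

assert metadata["style"] == "operator"
assert metadata["version"] == 4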
26 changes: 21 additions & 5 deletions python/tvm/relay/build_module.py
@@ -40,7 +40,23 @@
 from .backend.vm import VMExecutor


-def _update_target(target):
+def build_target_by_device_type_map(target):
+    """Build a map from DLDevice device_type to a Target used with that device.
+
+    At runtime, TVM assigns target code to DLDevices by determining a device_type for each Target.
+    This function handles this process at compile time and, as a side effect, validates that
+    exactly one target maps to one device_type.
+
+    Parameters
+    ----------
+    target : Target or str or dict
+        If a Target or str: assumes that exactly one device type is present in the model.
+        If a dict: keys are tvm.ndarray.device, values are the targets used for each device.
+
+    Returns
+    -------
+    """
     target = target if target else Target.current()
     if target is None:
         raise ValueError("Target is not set in env or passed as argument.")
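A hedged sketch of the mapping this produces (device_type 1 is kDLCPU; the normalization body itself is elided by this hunk):

from tvm.relay.build_module import build_target_by_device_type_map

targets = build_target_by_device_type_map("llvm")
# A plain str or Target implies a single device type, so this yields
# roughly {1: Target("llvm")}.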
@@ -132,7 +148,7 @@ def build(
        params : dict
            The parameters of the final graph.
        """
-        target = _update_target(target)
+        target = build_target_by_device_type_map(target)
        target, target_host = Target.check_and_update_host_consist(
            target, target_host, target_is_dict_key=False
        )
@@ -187,7 +203,7 @@ def optimize(self, mod, target=None, params=None):
        params : dict
            The parameters of the final graph.
        """
-        target = _update_target(target)
+        target = build_target_by_device_type_map(target)

        # Setup the params.
        if params:
@@ -316,7 +332,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
            "instead of deprecated parameter mod (tvm.relay.function.Function)",
            DeprecationWarning,
        )
-    target = _update_target(target)
+    target = build_target_by_device_type_map(target)
    if isinstance(target_host, (str, Target)):
        target_host = Target(target_host)
    elif target_host:
@@ -395,7 +411,7 @@ def optimize(mod, target=None, params=None):
        DeprecationWarning,
    )

-    target = _update_target(target)
+    target = build_target_by_device_type_map(target)

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
