Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def add_compile_parser(subparsers, _, json_params):
"--dump-code",
metavar="FORMAT",
default="",
help="comma separated list of formats to export the input model, e.g. 'asm,ll,relay'.",
help="comma separated list of formats to export the input model, e.g. 'asm,ll,tir,relay'.",
)
parser.add_argument(
"--model-format",
Expand Down Expand Up @@ -254,9 +254,9 @@ def compile_model(
output_format : str
What format to use when saving the function library. Must be one of "so" or "tar".
When compiling for a remote device without a cross compiler, "tar" will likely work better.
dump_code : list, optional
dump_code : list[str], optional
Dump the generated code for the specified source types, on
the requested target.
the requested target. Choose from: ["asm", "ll", "tir", "relay"].
target_host : str, optional
The target of the host machine if host-side code
needs to be generated.
Expand Down Expand Up @@ -290,7 +290,15 @@ def compile_model(
"""
mod, params = tvmc_model.mod, tvmc_model.params

if dump_code is None:
dump_code = []
if not isinstance(dump_code, list):
dump_code = [dump_code]
dumps = {}

config = parse_configs(pass_context_configs)
if "tir" in dump_code:
config, dumps = add_tir_to_dumps(config, dumps)

tvm_target, extra_targets = target_from_cli(target, additional_target_options)
tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host)
Expand Down Expand Up @@ -366,20 +374,16 @@ def compile_model(
)

# Generate output dump files with sources
if dump_code is None:
dump_code = []
if not isinstance(dump_code, list):
dump_code = [dump_code]
dumps = {}
for source_type in dump_code:
if use_vm:
lib = graph_module.lib
if source_type == "relay":
dumps[source_type] = str(mod)
elif source_type == "tir":
dumps[source_type] = "\n".join(dumps[source_type])
else:
lib = graph_module.get_lib()
# TODO lib.get_source call have inconsistent behavior for unsupported
# formats (@leandron).
source = str(mod) if source_type == "relay" else lib.get_source(source_type)
dumps[source_type] = source
lib = graph_module.lib if use_vm else graph_module.get_lib()
# TODO lib.get_source call have inconsistent behavior for unsupported
# formats (@leandron).
dumps[source_type] = lib.get_source(source_type)

# Create a new tvmc model package object from the graph definition.
package_path = tvmc_model.export_package(
Expand Down Expand Up @@ -440,6 +444,26 @@ def build(
)


def add_tir_to_dumps(config, dumps):
"""
Creates a debug pass that dumps TIR functions as a list of strings.
"""
key = "tir"
phase = 3 # final TIR phase before codegen
dumps[key] = []

@tvm.tir.transform.prim_func_pass(opt_level=0)
def _dump_tir_pass(tir_func, _, __):
dumps[key].append(str(tir_func))
return tir_func

tir_lower_passes = config.get("tir.add_lower_pass", [])
tir_lower_passes.append((phase, _dump_tir_pass))
config["tir.add_lower_pass"] = tir_lower_passes

return config, dumps


def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."):
"""
Serialize dump files to the disk.
Expand Down
25 changes: 24 additions & 1 deletion tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@

def test_save_dumps(tmpdir_factory):
tmpdir = tmpdir_factory.mktemp("data")
dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"}
dump_formats = {"relay": "fake relay", "tir": "fake tir", "ll": "fake llvm", "asm": "fake asm"}
tvmc.compiler.save_dumps("fake_module", dump_formats, dump_root=tmpdir)

assert path.exists("{}/{}".format(tmpdir, "fake_module.ll"))
assert path.exists("{}/{}".format(tmpdir, "fake_module.asm"))
assert path.exists("{}/{}".format(tmpdir, "fake_module.tir"))
assert path.exists("{}/{}".format(tmpdir, "fake_module.relay"))


Expand Down Expand Up @@ -87,6 +88,28 @@ def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant):
verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm)


def test_single_tir_dump(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")
tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant)
tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="tir")
dumps_path = tvmc_package.package_path + ".tir"
assert os.path.exists(dumps_path)
with open(dumps_path) as f:
assert "tir" in f.read()


def test_code_dumps(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")
tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant)
dump_code = ["asm", "ll", "tir", "relay"]
tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code=dump_code)
for ext in dump_code:
dumps_path = tvmc_package.package_path + "." + ext
assert os.path.exists(dumps_path)
with open(dumps_path) as f:
assert len(f.read()) > 0


# This test will be skipped if the AArch64 cross-compilation toolchain is not installed.
@pytest.mark.skipif(
not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed"
Expand Down