diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index eec80820cdb1..2c98085d7dfe 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -71,7 +71,7 @@ def add_compile_parser(subparsers, _, json_params): "--dump-code", metavar="FORMAT", default="", - help="comma separated list of formats to export the input model, e.g. 'asm,ll,relay'.", + help="comma separated list of formats to export the input model, e.g. 'asm,ll,tir,relay'.", ) parser.add_argument( "--model-format", @@ -254,9 +254,9 @@ def compile_model( output_format : str What format to use when saving the function library. Must be one of "so" or "tar". When compiling for a remote device without a cross compiler, "tar" will likely work better. - dump_code : list, optional + dump_code : list[str], optional Dump the generated code for the specified source types, on - the requested target. + the requested target. Choose from: ["asm", "ll", "tir", "relay"]. target_host : str, optional The target of the host machine if host-side code needs to be generated. @@ -290,7 +290,15 @@ def compile_model( """ mod, params = tvmc_model.mod, tvmc_model.params + if dump_code is None: + dump_code = [] + if not isinstance(dump_code, list): + dump_code = [dump_code] + dumps = {} + config = parse_configs(pass_context_configs) + if "tir" in dump_code: + config, dumps = add_tir_to_dumps(config, dumps) tvm_target, extra_targets = target_from_cli(target, additional_target_options) tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host) @@ -366,20 +374,16 @@ def compile_model( ) # Generate output dump files with sources - if dump_code is None: - dump_code = [] - if not isinstance(dump_code, list): - dump_code = [dump_code] - dumps = {} for source_type in dump_code: - if use_vm: - lib = graph_module.lib + if source_type == "relay": + dumps[source_type] = str(mod) + elif source_type == "tir": + dumps[source_type] = "\n".join(dumps[source_type]) else: - lib = graph_module.get_lib() - # TODO lib.get_source call have inconsistent behavior for unsupported - # formats (@leandron). - source = str(mod) if source_type == "relay" else lib.get_source(source_type) - dumps[source_type] = source + lib = graph_module.lib if use_vm else graph_module.get_lib() + # TODO lib.get_source call have inconsistent behavior for unsupported + # formats (@leandron). + dumps[source_type] = lib.get_source(source_type) # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( @@ -440,6 +444,26 @@ def build( ) +def add_tir_to_dumps(config, dumps): + """ + Creates a debug pass that dumps TIR functions as a list of strings. + """ + key = "tir" + phase = 3 # final TIR phase before codegen + dumps[key] = [] + + @tvm.tir.transform.prim_func_pass(opt_level=0) + def _dump_tir_pass(tir_func, _, __): + dumps[key].append(str(tir_func)) + return tir_func + + tir_lower_passes = config.get("tir.add_lower_pass", []) + tir_lower_passes.append((phase, _dump_tir_pass)) + config["tir.add_lower_pass"] = tir_lower_passes + + return config, dumps + + def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."): """ Serialize dump files to the disk. diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 3a3f297729fd..6bcf19056df3 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -40,11 +40,12 @@ def test_save_dumps(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("data") - dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"} + dump_formats = {"relay": "fake relay", "tir": "fake tir", "ll": "fake llvm", "asm": "fake asm"} tvmc.compiler.save_dumps("fake_module", dump_formats, dump_root=tmpdir) assert path.exists("{}/{}".format(tmpdir, "fake_module.ll")) assert path.exists("{}/{}".format(tmpdir, "fake_module.asm")) + assert path.exists("{}/{}".format(tmpdir, "fake_module.tir")) assert path.exists("{}/{}".format(tmpdir, "fake_module.relay")) @@ -87,6 +88,28 @@ def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant): verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm) +def test_single_tir_dump(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="tir") + dumps_path = tvmc_package.package_path + ".tir" + assert os.path.exists(dumps_path) + with open(dumps_path) as f: + assert "tir" in f.read() + + +def test_code_dumps(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + dump_code = ["asm", "ll", "tir", "relay"] + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code=dump_code) + for ext in dump_code: + dumps_path = tvmc_package.package_path + "." + ext + assert os.path.exists(dumps_path) + with open(dumps_path) as f: + assert len(f.read()) > 0 + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed"