import os

import onnx
import tvm
from tvm import relay


def _collect_sources(root):
    """Return the source text of every ``c``/``llvm`` module reachable from *root*.

    Visits *root* first and then each entry of ``imported_modules``
    recursively (depth-first, pre-order). The resulting list order determines
    the ``module_<idx>`` numbering used when the sources are written out.
    """
    sources = []

    def visit(m):
        if m.type_key in ("c", "llvm"):
            sources.append(m.get_source())
        for child in m.imported_modules:
            visit(child)

    visit(root)
    return sources


def load_onnx_model(model: str, *, input_shapes=None, output_dir="./sources"):
    """Compile an ONNX model with TVM's AOT executor and dump the generated sources.

    Despite the name, this does more than load the model. It:

    1. imports *model* into Relay;
    2. builds it for the ``c`` target (host ``c``) with the ``aot`` executor
       (``unpacked-api`` enabled) on the ``crt`` runtime, with USMP enabled;
    3. writes the source of every ``c``/``llvm`` module found in the build
       result to ``<output_dir>/module_<idx>.txt``.

    Args:
        model: Path to the ``.onnx`` file.
        input_shapes: Mapping of graph input name to shape, passed to
            ``relay.frontend.from_onnx``. Defaults to ``{'input_1': (1, 2, 3)}``.
        output_dir: Directory that receives the dumped sources. It is created
            if missing. Defaults to ``./sources``.
    """
    if input_shapes is None:
        input_shapes = {"input_1": (1, 2, 3)}

    onnx_model = onnx.load(model)
    mod, params = relay.frontend.from_onnx(onnx_model, input_shapes, freeze_params=True)

    runtime = tvm.relay.backend.Runtime("crt")
    executor = tvm.relay.backend.Executor(name="aot", options={"unpacked-api": 1})
    build_config = {
        "tir.disable_assert": True,
        "tir.disable_vectorize": True,
        "tir.usmp.enable": True,
        "tir.usmp.algorithm": "hill_climb",
    }

    with tvm.transform.PassContext(opt_level=3, config=build_config):
        factory = relay.build(
            mod,
            executor=executor,
            target=tvm.target.Target("c", host="c"),
            runtime=runtime,
            # Forward whatever the importer returned instead of discarding it.
            # With freeze_params=True this is presumably empty, but passing it
            # through keeps any remaining weights from being silently dropped.
            params=params,
        )

    sources = _collect_sources(factory.module)

    # The original assumed the directory already existed and crashed otherwise.
    os.makedirs(output_dir, exist_ok=True)
    for idx, source in enumerate(sources):
        path = os.path.join(output_dir, f"module_{idx}.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write(source)


# Run the compile-and-dump only when executed as a script, so importing this
# file does not trigger a full TVM build and write files as a side effect.
if __name__ == "__main__":
    load_onnx_model("./lstm_1_dense_in_3_out_2_ts_2.onnx")
