diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 190669aa7a6c..e6f4a1eaee2c 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -26,27 +26,47 @@ namespace printer { TVM_REGISTER_NODE_TYPE(IRFrameNode); +struct SortableFunction { + int priority; + GlobalVar gv; + BaseFunc func; + + explicit SortableFunction(const std::pair& obj) + : priority(0), gv(obj.first), func(obj.second) { + if (gv->name_hint == "main") { + priority = 1000; + } else if (obj.second->GetTypeKey() == "tir.PrimFunc") { + priority = 1; + } else if (obj.second->GetTypeKey() == "relax.expr.ExternFunc") { + priority = 2; + } else if (obj.second->GetTypeKey() == "relax.expr.Function") { + priority = 3; + } else { + LOG(FATAL) << "TypeError: TVMScript cannot print functions of type: " + << obj.second->GetTypeKey(); + } + } + + bool operator<(const SortableFunction& other) const { + if (this->priority != other.priority) { + return this->priority < other.priority; + } + return this->gv->name_hint < other.gv->name_hint; + } +}; + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc { - std::vector> functions{mod->functions.begin(), - mod->functions.end()}; - // print "main" first - std::sort(functions.begin(), functions.end(), [](const auto& lhs, const auto& rhs) { - String lhs_name = lhs.first->name_hint; - String rhs_name = rhs.first->name_hint; - if (lhs_name == "main") { - lhs_name = ""; - } - if (rhs_name == "main") { - rhs_name = ""; - } - return lhs_name < rhs_name; - }); + std::vector functions; + for (const auto& kv : mod->functions) { + functions.push_back(SortableFunction(kv)); + } + std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); - for (const auto& kv : functions) { - GlobalVar gv = kv.first; - BaseFunc func = kv.second; + for (const auto& entry : functions) { + const GlobalVar& gv = entry.gv; + const BaseFunc& func = entry.func; d->cfg->binding_names.push_back(gv->name_hint); Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv)); d->cfg->binding_names.pop_back();