diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 2a5412a5671f..8fd87a6304dd 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -249,7 +249,8 @@ class IRModuleNode : public Object { TVM_DLL GlobalVar GetGlobalVar(const String& str) const; /*! - * \brief Collect all global vars defined in this module. + * \brief Collect all global vars defined in this module, ordered by + * the global variable name. * \returns An array of global vars */ TVM_DLL Array GetGlobalVars() const; diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index b7b3f411ed41..4953c1c81701 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -475,10 +475,10 @@ def export_tvm( ------- irmodule : tvm.ir.IRModule The converted tvm IR representation of the model. - params : Dict[str, tvm.nd.array] - A dictionary of parameters corresponding to the weights of - the model. + params : List[Tuple[str, Parameter]] + A list of Parameters corresponding to the weights of the model. ext_mods : List[nn.ExternModule] + A list of ExternModules that are used in the model. """ # pylint: disable=import-outside-toplevel from . import spec as _spec diff --git a/src/ir/module.cc b/src/ir/module.cc index 2e60441e94d3..261fbfe087c6 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -183,6 +184,9 @@ tvm::Array IRModuleNode::GetGlobalVars() const { for (const auto& pair : global_var_map_) { global_vars.push_back(pair.second); } + std::sort(global_vars.begin(), global_vars.end(), [](const GlobalVar& lhs, const GlobalVar& rhs) { + return lhs->name_hint < rhs->name_hint; + }); return tvm::Array(global_vars); } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 8b5518212cc8..2cb226d56e27 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -89,7 +89,8 @@ class AlterOpImplMutator : public ExprMutator { op_buffer_axis_separators__(axis_separators_) {} IRModule Run() { - for (const auto& [gv, func] : mod_->functions) { + for (const auto& gv : mod_->GetGlobalVars()) { + const auto& func = mod_->Lookup(gv); if (func->IsInstance()) { relax::Function update_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, update_func); diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 28c7d74ef8d0..876c714c61e3 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -148,7 +148,8 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array ent for (const auto& name : entry_function_names) { entry_functions.insert(mod->GetGlobalVar(name)); } - for (const auto& [gv, func] : mod->functions) { + for (const auto& gv : mod->GetGlobalVars()) { + const auto& func = mod->Lookup(gv); if (func.as() || func->GetLinkageType() == LinkageType::kExternal) { entry_functions.insert(gv); } diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index a2a3e96dd567..3e762778d849 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -691,7 +691,8 @@ class OperatorFusor : public ExprMutator { * \return The new IRModule after transformation */ IRModule Transform() { - for (const auto& [gv, func] : mod_->functions) { + for (const auto& gv : mod_->GetGlobalVars()) { + const auto& func = mod_->Lookup(gv); // Only visit Relax function without attr kPrimitive. if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { auto updated_func = Downcast(VisitExpr(func)); @@ -1196,9 +1197,9 @@ class CompositeFunctionAnnotator : public ExprMutator { IRModule Run() { auto mod = builder_->GetContextIRModule(); - auto all_functions = mod->functions; - for (const auto& entry : all_functions) { - if (const auto* func = entry.second.as()) { + for (const auto& gv : mod->GetGlobalVars()) { + const auto& base_func = mod->Lookup(gv); + if (const auto* func = base_func.as()) { if (func->GetAttr(attr::kComposite).defined() || func->GetAttr(attr::kCodegen).defined()) { continue; @@ -1208,7 +1209,7 @@ class CompositeFunctionAnnotator : public ExprMutator { if (!new_body.same_as(func->body)) { auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs, func->span); - builder_->UpdateFunction(entry.first, new_func); + builder_->UpdateFunction(gv, new_func); } } } @@ -1272,11 +1273,12 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, support::Arena arena; for (const auto& pattern : patterns) { OperatorFusor::GroupMap group_map; - for (const auto& entry : mod->functions) { - if (entry.second->IsInstance()) { + for (const auto& gv : mod->GetGlobalVars()) { + const auto& base_func = mod->Lookup(gv); + if (base_func->IsInstance()) { continue; } - const FunctionNode* function = entry.second.as(); + const FunctionNode* function = base_func.as(); if (function->GetAttr(attr::kPrimitive).defined() || function->GetAttr(attr::kComposite).defined() || function->GetAttr(attr::kCodegen).defined()) { @@ -1285,8 +1287,8 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern, pattern->annotation_patterns, - pattern->check.value_or(nullptr), entry.second, - &arena, pattern->attrs_getter.value_or(nullptr)); + pattern->check.value_or(nullptr), base_func, &arena, + pattern->attrs_getter.value_or(nullptr)); for (const auto& [key, value] : map) { CHECK(!group_map.count(key)) << "ValueError: " diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 11785ab73ac6..3df17b29ca52 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -964,7 +964,8 @@ class TIRFuseMutator : public ExprMutator { static IRModule Transform(IRModule mod) { // Collect all primitive relax functions Map primitive_relax; - for (const auto& [gvar, base_func] : mod->functions) { + for (const auto& gvar : mod->GetGlobalVars()) { + const auto& base_func = mod->Lookup(gvar); // Only fuse primitive relax functions if (base_func->HasNonzeroAttr(attr::kPrimitive)) { if (auto func = base_func.as()) { diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 343c18acd7a9..e2e463ff2b2f 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -67,16 +67,18 @@ class LegalizeMutator : public ExprMutator { } IRModule Transform() { - for (const auto& [gv, func] : mod_->functions) { + for (const auto& gv : mod_->GetGlobalVars()) { + const auto& func = mod_->Lookup(gv); if (func->IsInstance()) { auto updated_func = Downcast(this->VisitExpr(func)); builder_->UpdateFunction(gv, Downcast(updated_func)); } } // Fill the "kTarget" attribute of PrimFunc - for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) { + const auto& mod = builder_->GetContextIRModule(); + for (const auto& gv : mod->GetGlobalVars()) { const tir::PrimFuncNode* prim_func; - if (tmap_.count(gv) && (prim_func = func.as())) { + if (tmap_.count(gv) && (prim_func = mod->Lookup(gv).as())) { auto f = WithAttr(GetRef(prim_func), tvm::attr::kTarget, tmap_[gv]); builder_->UpdateFunction(gv, f); }