Skip to content

Commit efeccea

Browse files
committed
refactor param binding
1 parent 109187f commit efeccea

File tree

4 files changed

+22
-21
lines changed

4 files changed

+22
-21
lines changed

src/relay/backend/build_module.cc

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,7 @@ class RelayBuildModule : public runtime::ModuleNode {
333333
IRModule OptimizeImpl(IRModule relay_module) {
334334
ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler.";
335335

336-
if (!params_.empty()) {
337-
ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function";
338-
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
339-
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
340-
auto new_main = BindParamsByName(main_func, params_);
341-
IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
342-
relay_module_ptr->Update(main_glb_var, new_main);
343-
}
336+
backend::BindParamsInModule(relay_module, params_);
344337

345338
Array<Pass> pass_seqs = GetPassPrefix(
346339
/*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false);

src/relay/backend/metaschedule_task_extraction.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace metaschedule {
3535
using meta_schedule::ExtractedTask;
3636

3737
Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Constant> params) {
38+
// backend::BindParamsInModule(mod, params);
3839
if (params.size()) {
3940
std::unordered_map<std::string, runtime::NDArray> params_;
4041
BaseFunc base_func = mod->Lookup("main");
@@ -51,18 +52,13 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
5152
auto opt_mod = seq(std::move(mod));
5253

5354
Array<ExtractedTask> tasks;
54-
LOG(INFO) << opt_mod;
55-
LOG(INFO) << opt_mod->Lookup("main");
5655
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) {
5756
if (exp->IsInstance<FunctionNode>()) {
5857
Function relay_func = Downcast<Function>(exp);
5958
if (relay_func->HasNonzeroAttr(attr::kPrimitive)) {
60-
LOG(INFO) << relay_func;
6159
Array<te::Tensor> outputs;
6260
std::string fused_name;
6361
std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func);
64-
LOG(INFO) << fused_name;
65-
LOG(INFO) << outputs;
6662
auto prim_func = tir::CreatePrimFunc(outputs);
6763
auto prim_fn_var = GlobalVar(fused_name);
6864
auto relay_mod = IRModule({{prim_fn_var, relay_func}});

src/relay/backend/utils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,25 @@ inline relay::Function BindParamsByName(
417417
return ret;
418418
}
419419

420+
inline void BindParamsInModule(IRModule mod,
421+
const std::unordered_map<std::string, runtime::NDArray>& params) {
422+
if (!params.empty()) {
423+
BaseFunc base_func = mod->Lookup("main");
424+
ICHECK(base_func->IsInstance<FunctionNode>());
425+
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params);
426+
auto gvar = mod->GetGlobalVar("main");
427+
mod->Add(gvar, f);
428+
}
429+
}
430+
431+
inline void BindParamsInModule(IRModule mod, Map<String, Constant> params) {
432+
std::unordered_map<std::string, runtime::NDArray> params_tmp;
433+
for (const auto& kv : params) {
434+
params_tmp[kv.first] = kv.second->data;
435+
}
436+
BindParamsInModule(mod, params_tmp);
437+
}
438+
420439
/*!
421440
* \brief Extract the shape from a Relay tensor type.
422441
* \param type The provided type.

src/relay/backend/vm/compiler.cc

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,14 +1034,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets,
10341034

10351035
IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
10361036
VLOG_CONTEXT << "VM Optimize";
1037-
if (params_.size()) {
1038-
BaseFunc base_func = mod->Lookup("main");
1039-
ICHECK(base_func->IsInstance<FunctionNode>())
1040-
<< "VM compiler expects to compile relay::Function";
1041-
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
1042-
auto gvar = mod->GetGlobalVar("main");
1043-
mod->Add(gvar, f);
1044-
}
1037+
backend::BindParamsInModule(mod, params_);
10451038

10461039
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(
10471040
/*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true);

0 commit comments

Comments
 (0)