Skip to content

Commit 4a5e4aa

Browse files
committed
move BindParams function to cc file
1 parent efeccea commit 4a5e4aa

File tree

2 files changed

+55
-47
lines changed

2 files changed

+55
-47
lines changed

src/relay/backend/utils.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,56 @@ std::vector<int64_t> ShapeToJSON(tvm::Array<IndexExpr> shape) {
308308
return ret;
309309
}
310310

311+
relay::Function BindParamsByName(relay::Function func,
312+
const std::unordered_map<std::string, runtime::NDArray>& params) {
313+
std::unordered_map<std::string, relay::Var> name_dict;
314+
std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
315+
for (auto arg : func->params) {
316+
const auto& name = arg->name_hint();
317+
if (name_dict.count(name)) {
318+
repeat_var.insert(name_dict[name]);
319+
} else {
320+
name_dict[name] = arg;
321+
}
322+
}
323+
324+
std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
325+
for (auto& kv : params) {
326+
if (name_dict.count(kv.first) == 0) {
327+
continue;
328+
}
329+
auto arg = name_dict.at(kv.first);
330+
if (repeat_var.count(arg)) {
331+
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
332+
}
333+
bind_dict[arg] = Constant(kv.second);
334+
}
335+
Expr bound_expr = relay::Bind(func, bind_dict);
336+
Function ret = Downcast<Function>(bound_expr);
337+
ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
338+
<< "\n";
339+
return ret;
340+
}
341+
342+
void BindParamsInModule(IRModule mod,
343+
const std::unordered_map<std::string, runtime::NDArray>& params) {
344+
if (!params.empty()) {
345+
BaseFunc base_func = mod->Lookup("main");
346+
ICHECK(base_func->IsInstance<FunctionNode>());
347+
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params);
348+
auto gvar = mod->GetGlobalVar("main");
349+
mod->Add(gvar, f);
350+
}
351+
}
352+
353+
void BindParamsInModule(IRModule mod, Map<String, Constant> params) {
354+
std::unordered_map<std::string, runtime::NDArray> params_tmp;
355+
for (const auto& kv : params) {
356+
params_tmp[kv.first] = kv.second->data;
357+
}
358+
BindParamsInModule(mod, params_tmp);
359+
}
360+
311361
} // namespace backend
312362
} // namespace relay
313363
} // namespace tvm

src/relay/backend/utils.h

Lines changed: 5 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -386,55 +386,13 @@ inline std::string DType2String(const tvm::DataType dtype) {
386386
* \param params params dict
387387
* \return relay::Function
388388
*/
389-
inline relay::Function BindParamsByName(
390-
relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
391-
std::unordered_map<std::string, relay::Var> name_dict;
392-
std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
393-
for (auto arg : func->params) {
394-
const auto& name = arg->name_hint();
395-
if (name_dict.count(name)) {
396-
repeat_var.insert(name_dict[name]);
397-
} else {
398-
name_dict[name] = arg;
399-
}
400-
}
401-
402-
std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
403-
for (auto& kv : params) {
404-
if (name_dict.count(kv.first) == 0) {
405-
continue;
406-
}
407-
auto arg = name_dict.at(kv.first);
408-
if (repeat_var.count(arg)) {
409-
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
410-
}
411-
bind_dict[arg] = Constant(kv.second);
412-
}
413-
Expr bound_expr = relay::Bind(func, bind_dict);
414-
Function ret = Downcast<Function>(bound_expr);
415-
ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
416-
<< "\n";
417-
return ret;
418-
}
389+
relay::Function BindParamsByName(relay::Function func,
390+
const std::unordered_map<std::string, runtime::NDArray>& params);
419391

420-
inline void BindParamsInModule(IRModule mod,
421-
const std::unordered_map<std::string, runtime::NDArray>& params) {
422-
if (!params.empty()) {
423-
BaseFunc base_func = mod->Lookup("main");
424-
ICHECK(base_func->IsInstance<FunctionNode>());
425-
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params);
426-
auto gvar = mod->GetGlobalVar("main");
427-
mod->Add(gvar, f);
428-
}
429-
}
392+
void BindParamsInModule(IRModule mod,
393+
const std::unordered_map<std::string, runtime::NDArray>& params);
430394

431-
inline void BindParamsInModule(IRModule mod, Map<String, Constant> params) {
432-
std::unordered_map<std::string, runtime::NDArray> params_tmp;
433-
for (const auto& kv : params) {
434-
params_tmp[kv.first] = kv.second->data;
435-
}
436-
BindParamsInModule(mod, params_tmp);
437-
}
395+
void BindParamsInModule(IRModule mod, Map<String, Constant> params);
438396

439397
/*!
440398
* \brief Extract the shape from a Relay tensor type.

0 commit comments

Comments
 (0)