@@ -386,55 +386,13 @@ inline std::string DType2String(const tvm::DataType dtype) {
386386 * \param params params dict
387387 * \return relay::Function
388388 */
389- inline relay::Function BindParamsByName (
390- relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
391- std::unordered_map<std::string, relay::Var> name_dict;
392- std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
393- for (auto arg : func->params ) {
394- const auto & name = arg->name_hint ();
395- if (name_dict.count (name)) {
396- repeat_var.insert (name_dict[name]);
397- } else {
398- name_dict[name] = arg;
399- }
400- }
401-
402- std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
403- for (auto & kv : params) {
404- if (name_dict.count (kv.first ) == 0 ) {
405- continue ;
406- }
407- auto arg = name_dict.at (kv.first );
408- if (repeat_var.count (arg)) {
409- LOG (FATAL) << " Multiple args in the function have name " << kv.first ;
410- }
411- bind_dict[arg] = Constant (kv.second );
412- }
413- Expr bound_expr = relay::Bind (func, bind_dict);
414- Function ret = Downcast<Function>(bound_expr);
415- ICHECK (ret.defined ()) << " The returning type is expected to be a Relay Function."
416- << " \n " ;
417- return ret;
418- }
389+ relay::Function BindParamsByName (relay::Function func,
390+ const std::unordered_map<std::string, runtime::NDArray>& params);
419391
420- inline void BindParamsInModule (IRModule mod,
421- const std::unordered_map<std::string, runtime::NDArray>& params) {
422- if (!params.empty ()) {
423- BaseFunc base_func = mod->Lookup (" main" );
424- ICHECK (base_func->IsInstance <FunctionNode>());
425- auto f = relay::backend::BindParamsByName (Downcast<Function>(base_func), params);
426- auto gvar = mod->GetGlobalVar (" main" );
427- mod->Add (gvar, f);
428- }
429- }
392+ void BindParamsInModule (IRModule mod,
393+ const std::unordered_map<std::string, runtime::NDArray>& params);
430394
431- inline void BindParamsInModule (IRModule mod, Map<String, Constant> params) {
432- std::unordered_map<std::string, runtime::NDArray> params_tmp;
433- for (const auto & kv : params) {
434- params_tmp[kv.first ] = kv.second ->data ;
435- }
436- BindParamsInModule (mod, params_tmp);
437- }
395+ void BindParamsInModule (IRModule mod, Map<String, Constant> params);
438396
439397/* !
440398 * \brief Extract the shape from a Relay tensor type.
0 commit comments