Skip to content

Commit 57f2882

Browse files
committed
fixed constant param bind
1 parent f099537 commit 57f2882

File tree

2 files changed

+8
-15
lines changed

2 files changed

+8
-15
lines changed

python/tvm/meta_schedule/integration.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from tvm._ffi import register_object, get_global_func
2525
from tvm.ir import IRModule, transform
26-
from tvm.relay import Any
26+
from tvm.relay import Any, const
2727
from tvm.relay import Function as RelayFunc
2828
from tvm.relay import vm
2929
from tvm.runtime import NDArray, Object
@@ -238,9 +238,11 @@ def extract_task_from_relay(
238238

239239
target = Target(target) if isinstance(target, str) else target
240240

241+
relay_params = {}
241242
for name, param in params.items():
242243
if isinstance(param, np.ndarray):
243-
params[name] = nd.array(param)
244+
param = nd.array(param)
245+
relay_params[name] = const(param)
244246

245247
if disabled_pass is None:
246248
disabled_pass = []
@@ -250,11 +252,10 @@ def extract_task_from_relay(
250252
if not isinstance(target, Target):
251253
target = Target(target)
252254

253-
with transform.PassContext(
255+
with target, transform.PassContext(
254256
opt_level=opt_level,
255257
config=pass_config,
256258
disabled_pass=disabled_pass,
257259
):
258-
with target:
259-
tasks = extract_task_func(mod, target, params)
260-
return tasks
260+
tasks = extract_task_func(mod, target, relay_params)
261+
return tasks

src/relay/backend/metaschedule_task_extraction.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,7 @@ namespace metaschedule {
3535
using meta_schedule::ExtractedTask;
3636

3737
Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Constant> params) {
38-
// backend::BindParamsInModule(mod, params);
39-
if (params.size()) {
40-
std::unordered_map<std::string, runtime::NDArray> params_;
41-
BaseFunc base_func = mod->Lookup("main");
42-
ICHECK(base_func->IsInstance<FunctionNode>());
43-
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
44-
auto gvar = mod->GetGlobalVar("main");
45-
mod->Add(gvar, f);
46-
}
38+
backend::BindParamsInModule(mod, params);
4739

4840
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
4941
pass_seqs.push_back(transform::FuseOps());

0 commit comments

Comments
 (0)