2323
2424from tvm ._ffi import register_object , get_global_func
2525from tvm .ir import IRModule , transform
26- from tvm .relay import Any
26+ from tvm .relay import Any , const
2727from tvm .relay import Function as RelayFunc
2828from tvm .relay import vm
2929from tvm .runtime import NDArray , Object
@@ -238,9 +238,11 @@ def extract_task_from_relay(
238238
239239 target = Target (target ) if isinstance (target , str ) else target
240240
241+ relay_params = {}
241242 for name , param in params .items ():
242243 if isinstance (param , np .ndarray ):
243- params [name ] = nd .array (param )
244+ param = nd .array (param )
245+ relay_params [name ] = const (param )
244246
245247 if disabled_pass is None :
246248 disabled_pass = []
@@ -250,11 +252,10 @@ def extract_task_from_relay(
250252 if not isinstance (target , Target ):
251253 target = Target (target )
252254
253- with transform .PassContext (
255+ with target , transform .PassContext (
254256 opt_level = opt_level ,
255257 config = pass_config ,
256258 disabled_pass = disabled_pass ,
257259 ):
258- with target :
259- tasks = extract_task_func (mod , target , params )
260- return tasks
260+ tasks = extract_task_func (mod , target , relay_params )
261+ return tasks
0 commit comments