@@ -278,9 +278,7 @@ int LowerToTECompute::const_index = 0;
 class ScheduleBuilder : ExprVisitor {
  public:
   explicit ScheduleBuilder(Target target, bool create_schedule = true)
-      : target_(target),
-
-        create_schedule_(create_schedule) {
+      : target_(target), create_schedule_(create_schedule) {
     // Whether to use auto_scheduler schedule.
     use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
     use_meta_schedule_ = backend::IsMetaScheduleEnabled();
@@ -289,6 +287,7 @@ class ScheduleBuilder : ExprVisitor {
   CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
     LowerToTECompute lower_te_compute(target_);
     Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
+    Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
     std::string candidate_name = lower_te_compute.readable_name_stream_.str();
     VisitExpr(relay_func->body);
 
@@ -332,7 +331,7 @@ class ScheduleBuilder : ExprVisitor {
       }
     }
     if (use_meta_schedule_) {
-      prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
+      prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
       Optional<ObjectRef> opt_mod_or_base_func =
           meta_schedule::MetaScheduleContext::QueryInsideWithScope(
               prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
@@ -359,9 +358,8 @@ class ScheduleBuilder : ExprVisitor {
       }
     }
 
-    return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
-                      prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
-                      lower_te_compute.constant_tensors_);
+    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
+                      IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
   }
 
   void VisitExpr_(const CallNode* call_node) final {
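
Note on the net effect of this diff: the meta_schedule branch now builds the PrimFunc from the function inputs concatenated with the outputs, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)), instead of calling tir::CreatePrimFuncFromOutputs(tensor_outs), and the same fn_inputs array is forwarded to CachedFunc. The sketch below is not part of this commit; it only illustrates the ordering produced by the Array Concat helper (assumed here to be the one declared in tvm/runtime/container/array.h), using Array<String> as a stand-in for the Array<te::Tensor> values in the diff. All names in it are illustrative.

// Hedged sketch, not upstream code: Concat(lhs, rhs) returns lhs with rhs appended,
// so CreatePrimFunc sees the function inputs followed by the outputs.
#include <iostream>
#include <string>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>

int main() {
  using tvm::runtime::Array;
  using tvm::runtime::String;

  // Stand-ins for lower_te_compute.fn_inputs_ and tensor_outs; the names are made up.
  Array<String> fn_inputs{String("input_a"), String("input_b")};
  Array<String> tensor_outs{String("output_0")};

  // Array Concat helper: returns a copy of fn_inputs with tensor_outs appended.
  Array<String> prim_func_args = Concat(fn_inputs, tensor_outs);

  for (const String& name : prim_func_args) {
    std::cout << std::string(name) << "\n";  // prints input_a, input_b, output_0
  }
  return 0;
}

Building this sketch requires linking against libtvm; it exists only to make the argument ordering of the changed call visible.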