Skip to content

Commit eb1bc7e

Browse files
committed
fixed merge conflict
1 parent 6e68fd9 commit eb1bc7e

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/relay/backend/te_compiler_cache.cc

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,7 @@ int LowerToTECompute::const_index = 0;
278278
class ScheduleBuilder : ExprVisitor {
279279
public:
280280
explicit ScheduleBuilder(Target target, bool create_schedule = true)
281-
: target_(target),
282-
283-
create_schedule_(create_schedule) {
281+
: target_(target), create_schedule_(create_schedule) {
284282
// Whether to use auto_scheduler schedule.
285283
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
286284
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
@@ -289,6 +287,7 @@ class ScheduleBuilder : ExprVisitor {
289287
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
290288
LowerToTECompute lower_te_compute(target_);
291289
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
290+
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
292291
std::string candidate_name = lower_te_compute.readable_name_stream_.str();
293292
VisitExpr(relay_func->body);
294293

@@ -332,7 +331,7 @@ class ScheduleBuilder : ExprVisitor {
332331
}
333332
}
334333
if (use_meta_schedule_) {
335-
prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
334+
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
336335
Optional<ObjectRef> opt_mod_or_base_func =
337336
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
338337
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
@@ -359,9 +358,8 @@ class ScheduleBuilder : ExprVisitor {
359358
}
360359
}
361360

362-
return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
363-
prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
364-
lower_te_compute.constant_tensors_);
361+
return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
362+
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
365363
}
366364

367365
void VisitExpr_(const CallNode* call_node) final {

0 commit comments

Comments
 (0)