diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 34ce3f24da27..02da4f999513 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -41,6 +42,7 @@ #include #include +#include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" #include "utils.h" @@ -178,16 +180,11 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } } if (use_meta_schedule_) { - const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs"); - const auto* f_meta_schedule = - runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInsideWithScope"); - ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered"; - ICHECK(f_meta_schedule) - << "meta_schedule.MetaScheduleContextQueryInsideWithScope is not registered"; - prim_func = (*f_create_func)(tensor_outs); + prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs); Optional opt_mod_or_base_func = - (*f_meta_schedule)(prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), - target_, Array{IRModule({{prim_fn_var, prim_func}})}); + meta_schedule::MetaScheduleContext::QueryInsideWithScope( + prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, + Array{IRModule({{prim_fn_var, prim_func}})}); if (const auto* result = opt_mod_or_base_func.as()) { prim_func = GetRef(result); } else { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 94dd0b044d71..4e160605f523 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -264,7 +264,6 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf /*annotations=*/extern_op->attrs)); } -/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const Array& arg_list) { // Step 1. Create tensor read graph. Array arg_ops; diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h new file mode 100644 index 000000000000..d911e5ebcdb7 --- /dev/null +++ b/src/te/operation/create_primfunc.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_TE_OPERATION_CREATE_PRIMFUNC_H_ +#define TVM_TE_OPERATION_CREATE_PRIMFUNC_H_ + +#include +#include +#include + +namespace tvm { +namespace tir { + +/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ +PrimFunc CreatePrimFunc(const Array& arg_list); + +/*! \brief Create a schedulable TensorIR func from TE compute outputs. */ +PrimFunc CreatePrimFuncFromOutputs(const Array& outputs); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TE_OPERATION_CREATE_PRIMFUNC_H_