Skip to content

Commit 86ba517

Browse files
authored
[Refactor] Expose meta-schedule related packed func in a header and call it directly (#10470)
* [MetaSchedule] Expose CreatePrimFuncFromOutputs in a header and call it directly * add include guard * exposed ContextQueryInsideWithScope too * oops * add tir namespace for clarity * address comment
1 parent 4ae142f commit 86ba517

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

src/relay/backend/te_compiler_cache.cc

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <tvm/driver/driver_api.h>
2323
#include <tvm/ir/type_functor.h>
24+
#include <tvm/meta_schedule/integration.h>
2425
#include <tvm/relay/analysis.h>
2526
#include <tvm/relay/attrs/device_copy.h>
2627
#include <tvm/relay/expr.h>
@@ -41,6 +42,7 @@
4142
#include <utility>
4243
#include <vector>
4344

45+
#include "../../te/operation/create_primfunc.h"
4446
#include "../op/memory/memory.h"
4547
#include "../transforms/pass_utils.h"
4648
#include "utils.h"
@@ -178,16 +180,11 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
178180
}
179181
}
180182
if (use_meta_schedule_) {
181-
const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs");
182-
const auto* f_meta_schedule =
183-
runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInsideWithScope");
184-
ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered";
185-
ICHECK(f_meta_schedule)
186-
<< "meta_schedule.MetaScheduleContextQueryInsideWithScope is not registered";
187-
prim_func = (*f_create_func)(tensor_outs);
183+
prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
188184
Optional<ObjectRef> opt_mod_or_base_func =
189-
(*f_meta_schedule)(prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}),
190-
target_, Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
185+
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
186+
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
187+
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
191188
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
192189
prim_func = GetRef<tir::PrimFunc>(result);
193190
} else {

src/te/operation/create_primfunc.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
264264
/*annotations=*/extern_op->attrs));
265265
}
266266

267-
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
268267
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
269268
// Step 1. Create tensor read graph.
270269
Array<te::Operation> arg_ops;

src/te/operation/create_primfunc.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_TE_OPERATION_CREATE_PRIMFUNC_H_
21+
#define TVM_TE_OPERATION_CREATE_PRIMFUNC_H_
22+
23+
#include <tvm/runtime/container/array.h>
24+
#include <tvm/te/tensor.h>
25+
#include <tvm/tir/function.h>
26+
27+
namespace tvm {
28+
namespace tir {
29+
30+
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
31+
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);
32+
33+
/*! \brief Create a schedulable TensorIR func from TE compute outputs. */
34+
PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs);
35+
36+
} // namespace tir
37+
} // namespace tvm
38+
39+
#endif // TVM_TE_OPERATION_CREATE_PRIMFUNC_H_

0 commit comments

Comments
 (0)