Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/meta_schedule/integration.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ class ApplyHistoryBest : public MetaScheduleContext {
ApplyHistoryBestNode);
};

Optional<ObjectRef> ContextQueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched);
} // namespace meta_schedule
} // namespace tvm

Expand Down
6 changes: 6 additions & 0 deletions src/meta_schedule/integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
return NullOpt;
}

Optional<ObjectRef> ContextQueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched) {
return MetaScheduleContext::QueryInsideWithScope(task_name, mod, target, dispatched);
}

/**************** FFI ****************/

class MetaScheduleContextInternal {
Expand Down
16 changes: 6 additions & 10 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/meta_schedule/integration.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
Expand All @@ -41,6 +42,7 @@
#include <utility>
#include <vector>

#include "../../te/operation/create_primfunc.h"
#include "../op/memory/memory.h"
#include "../transforms/pass_utils.h"
#include "utils.h"
Expand Down Expand Up @@ -178,16 +180,10 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
}
}
if (use_meta_schedule_) {
const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs");
const auto* f_meta_schedule =
runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInsideWithScope");
ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered";
ICHECK(f_meta_schedule)
<< "meta_schedule.MetaScheduleContextQueryInsideWithScope is not registered";
prim_func = (*f_create_func)(tensor_outs);
Optional<ObjectRef> opt_mod_or_base_func =
(*f_meta_schedule)(prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}),
target_, Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
Optional<ObjectRef> opt_mod_or_base_func = meta_schedule::ContextQueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
Expand Down
35 changes: 35 additions & 0 deletions src/te/operation/create_primfunc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_TE_OPERATION_CREATE_PRIMFUNC_H_
#define TVM_TE_OPERATION_CREATE_PRIMFUNC_H_

#include <tvm/runtime/container/array.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace tir {

PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs);

} // namespace tir
} // namespace tvm

#endif // TVM_TE_OPERATION_CREATE_PRIMFUNC_H_