Commit 109187f

New relay backend for meta schedule task extraction
1 parent 6f01901 commit 109187f

File tree

4 files changed: +127 -1 lines changed


python/tvm/meta_schedule/integration.py

Lines changed: 18 additions & 1 deletion
@@ -18,7 +18,10 @@
 from contextlib import contextmanager
 from typing import Callable, Dict, List, Optional, Union
 
-from tvm._ffi import register_object
+import numpy as np
+import tvm.runtime.ndarray as nd
+
+from tvm._ffi import register_object, get_global_func
 from tvm.ir import IRModule, transform
 from tvm.relay import Any
 from tvm.relay import Function as RelayFunc
@@ -230,6 +233,20 @@ def extract_task_from_relay(
         The tasks extracted from this network
     """
 
+    extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask")
+    assert extract_task_func
+
+    target = Target(target) if isinstance(target, str) else target
+
+    for name, param in params.items():
+        if isinstance(param, np.ndarray):
+            params[name] = nd.array(param)
+
+    with transform.PassContext(opt_level=opt_level):
+        with target:
+            tasks = extract_task_func(mod, target, params)
+    return tasks
+
 @contextmanager
 def _autotvm_silencer():
     from tvm import autotvm  # pylint: disable=import-outside-toplevel
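
As a usage sketch (not part of the commit): assuming extract_task_from_relay takes roughly (mod, target, params, opt_level), as the diff suggests, driving it on a toy one-conv Relay module could look like the following. The workload and the task_name attribute are illustrative assumptions; task_name follows the first constructor argument of ExtractedTask in the C++ file below.

import numpy as np
import tvm
from tvm import relay
from tvm.meta_schedule.integration import extract_task_from_relay

# Toy workload: a single conv2d, so extraction should yield one fused task.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# numpy values are fine here; the wrapper above converts them with nd.array.
params = {"weight": np.random.randn(8, 3, 3, 3).astype("float32")}

tasks = extract_task_from_relay(mod, "llvm", params)
for task in tasks:
    print(task.task_name)  # name_hint of the extracted PrimFunc's GlobalVar
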
Lines changed: 89 additions & 0 deletions (new file)

@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/meta_schedule/integration.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/function.h>
+#include <tvm/target/target.h>
+
+#include "../../te/operation/create_primfunc.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace metaschedule {
+
+using meta_schedule::ExtractedTask;
+
+Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Constant> params) {
+  if (params.size()) {
+    std::unordered_map<std::string, runtime::NDArray> params_;
+    for (const auto& kv : params) {
+      params_[kv.first] = kv.second->data;  // copy constant data so BindParamsByName binds it
+    }
+    BaseFunc base_func = mod->Lookup("main");
+    ICHECK(base_func->IsInstance<FunctionNode>());
+    auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
+    auto gvar = mod->GetGlobalVar("main");
+    mod->Add(gvar, f);
+  }
+
+  Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
+  pass_seqs.push_back(transform::FuseOps());
+
+  transform::Sequential seq(pass_seqs);
+  auto opt_mod = seq(std::move(mod));
+
+  Array<ExtractedTask> tasks;
+  LOG(INFO) << opt_mod;
+  LOG(INFO) << opt_mod->Lookup("main");
+  PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) {
+    if (exp->IsInstance<FunctionNode>()) {
+      Function relay_func = Downcast<Function>(exp);
+      if (relay_func->HasNonzeroAttr(attr::kPrimitive)) {
+        LOG(INFO) << relay_func;
+        Array<te::Tensor> outputs;
+        std::string fused_name;
+        std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func);
+        LOG(INFO) << fused_name;
+        LOG(INFO) << outputs;
+        auto prim_func = tir::CreatePrimFunc(outputs);
+        auto prim_fn_var = GlobalVar(fused_name);
+        auto relay_mod = IRModule({{prim_fn_var, relay_func}});
+        auto tir_mod = IRModule({{prim_fn_var, prim_func}});
+        tasks.push_back(ExtractedTask(prim_fn_var->name_hint, relay_mod, target, {tir_mod}));
+      }
+    }
+  });
+
+  return tasks;
+}
+
+TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask")
+    .set_body_typed([](IRModule mod, Target target, Map<String, Constant> params) {
+      return ExtractTask(mod, target, params);
+    });
+
+}  // namespace metaschedule
+}  // namespace backend
+}  // namespace relay
+}  // namespace tvm
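
The TVM_REGISTER_GLOBAL call at the bottom is what the Python wrapper above resolves via get_global_func. A minimal sketch of looking the packed function up by hand; only the registered name comes from this diff, everything else is illustrative:

import tvm

# Resolve the packed function registered as
# "relay.backend.MetaScheduleExtractTask"; allow_missing=True returns None
# instead of raising if this backend is not compiled in.
f = tvm.get_global_func("relay.backend.MetaScheduleExtractTask", allow_missing=True)
if f is not None:
    # f(mod, target, params) mirrors ExtractTask's C++ signature:
    # (IRModule, Target, Map<String, Constant>) -> Array<ExtractedTask>.
    print("task extraction backend available")
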

src/relay/backend/te_compiler_cache.cc

Lines changed: 18 additions & 0 deletions
@@ -754,6 +754,24 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
   return MakeShapeFunc().Create(prim_func, target, renamer);
 }
 
+std::pair<Array<te::Tensor>, std::string> LowerTECompute(Target target, const Function& relay_func,
+                                                         bool return_inputs) {
+  LowerToTECompute lower_te_compute(target);
+  auto outputs = lower_te_compute.Lower(relay_func, [&](std::string name) { return name; });
+  // Following ScheduleBuilder, remove placeholder ops from outputs.
+  tvm::Array<te::Tensor> tensor_outs;
+  for (const auto& tensor : outputs) {
+    if (!tensor->op.as<te::PlaceholderOpNode>()) {
+      tensor_outs.push_back(tensor);
+    }
+  }
+  if (return_inputs) {
+    return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs),
+                          lower_te_compute.candidate_name_);
+  }
+  return std::make_pair(tensor_outs, lower_te_compute.candidate_name_);
+}
+
 /*!
  * \brief Get unique name from name.
  * \param name The original name.
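
For intuition on the placeholder filtering, a toy sketch that is not part of the commit: te.create_prim_func, the Python face of the CreatePrimFunc used by the new backend, consumes inputs followed by non-placeholder outputs, which is exactly the pair LowerTECompute assembles when return_inputs is true.

from tvm import te

# A is a PlaceholderOpNode (a function input); B is the real compute output.
A = te.placeholder((128,), name="A", dtype="float32")
B = te.compute((128,), lambda i: A[i] * 2.0, name="B")

# LowerTECompute drops A from `outputs` (placeholder filtering) and, with
# return_inputs=true, prepends the inputs again, yielding [A, B]: the
# argument list CreatePrimFunc consumes to build a TIR PrimFunc.
prim_func = te.create_prim_func([A, B])
print(prim_func)
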

src/relay/backend/te_compiler_cache.h

Lines changed: 2 additions & 0 deletions
@@ -204,6 +204,8 @@ class CCacheValue : public ObjectRef {
 
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
 
+std::pair<Array<te::Tensor>, std::string> LowerTECompute(Target target, const Function& relay_func, bool return_inputs=true);
+
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.