Skip to content

Commit ce335c3

Browse files
masahijunrushao
andauthored
[Metaschedule] New relay backend for meta schedule task extraction (#10578)
* New relay backend for meta schedule task extraction commit 501fac6 Merge: 076fa33 ce8c563 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 14:16:47 2022 +0900 New relay backend for meta schedule task extraction commit ce8c563 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 14:12:30 2022 +0900 fix cpplint commit dfa4fb0 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 14:09:11 2022 +0900 update expected op list in test_meta_schedule_integration_extract_from_resnet to remove dep on Ansor commit a98182e Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 13:56:35 2022 +0900 fixed test_meta_schedule_integration_apply_history_best commit 40d52a1 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 13:50:43 2022 +0900 uniquefy task names commit dfaf496 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 13:45:30 2022 +0900 dedup tasks commit e49d500 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 12:59:45 2022 +0900 return reversed list commit 74636be Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 12:39:58 2022 +0900 refactor commit 99f1701 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 12:34:14 2022 +0900 clean up integration.cc and Query interface commit 3f93a1e Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 11:54:57 2022 +0900 check in minor vnni-related change commit af3e988 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 07:36:35 2022 +0900 Removed TaskExtraction node commit 7b4d35e Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 05:42:56 2022 +0900 add doc to util functions commit 3c5a318 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 05:27:53 2022 +0900 rename to task extraction commit 57f2882 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 05:24:37 2022 +0900 fixed constant param bind commit f099537 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 05:10:44 2022 +0900 remove unused stuff from python extract_tasks_from_relay commit 4a5e4aa Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 05:10:30 2022 +0900 move BindParams function to cc file commit efeccea Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 03:56:05 2022 +0900 refactor param binding commit 109187f Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 02:21:58 2022 +0900 New relay backend for meta schedule task extraction commit 6f01901 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 11:25:44 2022 +0900 fixed anchor impl selection commit be6c258 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 10:57:02 2022 +0900 Forgot visiting arg in ScheduleBuilder CallNode vsit commit 0c6d4a6 Author: Masahiro Masuda <[email protected]> Date: Fri Mar 11 10:45:08 2022 +0900 add public, fix include path convention commit 4cd3a16 Author: Masahiro Masuda <[email protected]> Date: Thu Mar 10 18:43:15 2022 +0900 removed create_schedule stuff commit eb1bc7e Author: Masahiro Masuda <[email protected]> Date: Thu Mar 10 18:13:42 2022 +0900 fixed merge conflict commit 6e68fd9 Author: Masahiro Masuda <[email protected]> Date: Thu Mar 10 14:27:34 2022 +0900 Decouple TE compute and schedule lowering in ScheduleBuilder * update integration.h doc * remove unused import * fix mypy check * use_meta_schedule restored, now extracts the same task as Ansor * python doc update * unused import * cache_ -> cache, suppres "Cannot find workdload" warning * Update src/relay/backend/task_extraction.cc and te_compiler_cache.cc Co-authored-by: Junru Shao <[email protected]> * removed unnecessary include * fixed build * drop relay.const on params * updated comment in integration.cc * update schedule_rule name to prepend "metaschedule" * typo fix * more nit change * make the output of Query Optional * update py doc * remove TODO comment on parse_mod Co-authored-by: Junru Shao <[email protected]>
1 parent ab4289d commit ce335c3

File tree

13 files changed

+264
-245
lines changed

13 files changed

+264
-245
lines changed

include/tvm/meta_schedule/integration.h

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,12 @@ class MetaScheduleContextNode : public runtime::Object {
8686
* \param target Target info
8787
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
8888
* NullOpt means the dispatch needs to be done in the context.
89-
* \return There are different types of the output
90-
* 1) NullOpt if there is no feedback hint
91-
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
92-
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
93-
* 4) IRModule for unified dispatch
89+
* \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
90+
* under IRModule for more general future use. NullOpt is returned
91+
* if there is no feedback hint.
9492
*/
95-
virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
96-
Optional<Array<IRModule>> dispatched) = 0;
93+
virtual Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
94+
Optional<Array<IRModule>> dispatched) = 0;
9795

9896
static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
9997
TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object);
@@ -123,15 +121,13 @@ class MetaScheduleContext : public runtime::ObjectRef {
123121
* \param mod The high-level IR
124122
* \param target Target info
125123
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
126-
* \return There are different types of the output
127-
* 1) NullOpt if there is no feedback hint
128-
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
129-
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
130-
* 4) IRModule for unified dispatch
124+
* \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
125+
* under IRModule for more general future use. NullOpt is returned
126+
* if there is no feedback hint
131127
*/
132-
static Optional<ObjectRef> QueryInsideWithScope(runtime::String task_name, IRModule mod,
133-
Target target,
134-
Optional<Array<IRModule>> dispatched);
128+
static Optional<IRModule> QueryInsideWithScope(runtime::String task_name, IRModule mod,
129+
Target target,
130+
Optional<Array<IRModule>> dispatched);
135131

136132
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef,
137133
MetaScheduleContextNode);
@@ -145,38 +141,6 @@ class MetaScheduleContext : public runtime::ObjectRef {
145141
void ExitWithScope();
146142
};
147143

148-
/**************** TaskExtraction ****************/
149-
150-
/*!
151-
* \brief An integration context for task extraction
152-
*/
153-
class TaskExtractionNode : public MetaScheduleContextNode {
154-
public:
155-
/*! \brief The extracted tasks */
156-
Array<ExtractedTask> tasks{nullptr};
157-
158-
void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }
159-
160-
// Inherited from base class
161-
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
162-
Optional<Array<IRModule>> dispatched) final;
163-
164-
static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
165-
TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode);
166-
};
167-
168-
/*!
169-
* \brief Managed reference to TaskExtractionNode
170-
* \sa TaskExtractionNode
171-
*/
172-
class TaskExtraction : public MetaScheduleContext {
173-
public:
174-
/*! \brief The path to a cache file storing extracted tasks */
175-
TaskExtraction();
176-
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext,
177-
TaskExtractionNode);
178-
};
179-
180144
/**************** ApplyHistoryBest ****************/
181145

182146
/*!
@@ -193,8 +157,8 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode {
193157
}
194158

195159
// Inherited from base class
196-
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
197-
Optional<Array<IRModule>> dispatched) final;
160+
Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
161+
Optional<Array<IRModule>> dispatched) final;
198162

199163
static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
200164
TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode);

python/tvm/meta_schedule/integration.py

Lines changed: 30 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Meta schedule integration with high-level IR"""
18-
from contextlib import contextmanager
19-
from typing import Callable, Dict, List, Optional, Union
18+
from typing import Dict, List, Optional, Union
2019

21-
from tvm._ffi import register_object
20+
import numpy as np # type: ignore
21+
import tvm.runtime.ndarray as nd
22+
23+
from tvm._ffi import register_object, get_global_func
2224
from tvm.ir import IRModule, transform
2325
from tvm.relay import Any
2426
from tvm.relay import Function as RelayFunc
25-
from tvm.relay import vm
2627
from tvm.runtime import NDArray, Object
2728
from tvm.target import Target
28-
from tvm.tir import PrimFunc
2929

3030
from . import _ffi_api
3131
from .database import Database
@@ -77,7 +77,7 @@ def query(
7777
mod: IRModule,
7878
target: Target,
7979
dispatched: Optional[List[IRModule]],
80-
) -> Union[IRModule, RelayFunc, PrimFunc, None]:
80+
) -> Union[IRModule, None]:
8181
"""The entry point of the integration
8282
8383
Parameters
@@ -93,12 +93,9 @@ def query(
9393
9494
Returns
9595
-------
96-
result : Union[IRModule, RelayFunc, PrimFunc, None]
97-
There are different types of the output:
98-
1) NullOpt if there is no feedback hint;
99-
2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
100-
3) relay::Function if `mod` should be dispatched to BYOC workflow;
101-
4) IRModule for unified dispatch
96+
result : IRModule or None
97+
Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
98+
more general future use. None is returned if there is no feedback hint.
10299
"""
103100
return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member
104101
self,
@@ -126,7 +123,7 @@ def query_inside_with_scope(
126123
mod: IRModule,
127124
target: Target,
128125
dispatched: Optional[List[IRModule]],
129-
) -> Union[IRModule, RelayFunc, PrimFunc, None]:
126+
) -> Union[IRModule, None]:
130127
"""The entry point of the integration workflow. The compilation process of the high-level
131128
IR should call this method for task extraction and for feedback hints
132129
@@ -137,7 +134,7 @@ def query_inside_with_scope(
137134
def query_inside_with_scope(task_name, mod, dispatched):
138135
ctx = MetaScheduleContext.current()
139136
assert ctx is not None
140-
ctx.query(task_name, mod, target, dispatched)
137+
mod = ctx.query(task_name, mod, target, dispatched)
141138
142139
Parameters
143140
----------
@@ -152,12 +149,9 @@ def query_inside_with_scope(task_name, mod, dispatched):
152149
153150
Returns
154151
-------
155-
result : Union[IRModule, RelayFunc, PrimFunc, None]
156-
There are different types of the output:
157-
1) NullOpt if there is no feedback hint;
158-
2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
159-
3) relay::Function if `mod` should be dispatched to BYOC workflow;
160-
4) IRModule for unified dispatch
152+
result : IRModule or None
153+
Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
154+
more general future use. None is returned if there is no feedback hint.
161155
"""
162156
return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member
163157
task_name,
@@ -176,17 +170,6 @@ def __exit__(self, ptype, value, trace) -> None:
176170
_ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member
177171

178172

179-
@register_object("meta_schedule.TaskExtraction")
180-
class TaskExtraction(MetaScheduleContext):
181-
"""An integration context for task extraction"""
182-
183-
tasks: List[ExtractedTask]
184-
"""The extracted tasks"""
185-
186-
def __init__(self) -> None:
187-
self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member
188-
189-
190173
@register_object("meta_schedule.ApplyHistoryBest")
191174
class ApplyHistoryBest(MetaScheduleContext):
192175
"""An integration context that allows application of historically best record from database"""
@@ -230,45 +213,32 @@ def extract_task_from_relay(
230213
The tasks extracted from this network
231214
"""
232215

233-
@contextmanager
234-
def _autotvm_silencer():
235-
from tvm import autotvm # pylint: disable=import-outside-toplevel
236-
237-
silent = autotvm.GLOBAL_SCOPE.silent
238-
autotvm.GLOBAL_SCOPE.silent = True
239-
try:
240-
yield
241-
finally:
242-
autotvm.GLOBAL_SCOPE.silent = silent
216+
extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask")
217+
assert extract_task_func
243218

244-
def _thread_run(func: Callable[[], None]) -> None:
245-
import threading # pylint: disable=import-outside-toplevel
219+
target = Target(target) if isinstance(target, str) else target
246220

247-
thread = threading.Thread(target=func)
248-
thread.start()
249-
thread.join()
221+
relay_params = {}
222+
for name, param in params.items():
223+
if isinstance(param, np.ndarray):
224+
param = nd.array(param)
225+
relay_params[name] = param
250226

251227
if disabled_pass is None:
252228
disabled_pass = []
253229
if pass_config is None:
254230
pass_config = {"relay.backend.use_meta_schedule": True}
255231

256-
env = TaskExtraction()
257232
if isinstance(mod, RelayFunc):
258233
mod = IRModule.from_expr(mod)
259234
if not isinstance(target, Target):
260235
target = Target(target)
261236

262-
def _func():
263-
with env, _autotvm_silencer(), transform.PassContext(
264-
config=pass_config,
265-
disabled_pass=disabled_pass,
266-
opt_level=opt_level,
267-
):
268-
compiler = vm.VMCompiler()
269-
if params:
270-
compiler.set_params(params)
271-
compiler.lower(mod, target)
272-
273-
_thread_run(_func)
274-
return env.tasks
237+
with target, transform.PassContext(
238+
opt_level=opt_level,
239+
config=pass_config,
240+
disabled_pass=disabled_pass,
241+
):
242+
tasks = extract_task_func(mod, target, relay_params)
243+
# Tasks are extracted via post order visit, return the reversed list.
244+
return list(reversed(tasks))

python/tvm/topi/x86/batch_matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y):
4747
axis=ak,
4848
),
4949
tag="batch_matmul_vnni",
50+
attrs={"schedule_rule": "meta_schedule.batch_matmul_vnni"},
5051
)
5152

5253
_, a_y, _ = z.op.axis

python/tvm/topi/x86/dense.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None):
296296
axis=ak,
297297
),
298298
tag="dense_vnni",
299+
attrs={"schedule_rule": "meta_schedule.dense_vnni"},
299300
)
300301

301302
if bias is not None:

0 commit comments

Comments
 (0)