Skip to content

Commit b71f47f

Browse files
committed
[MetaSchedule] Add Gradient Based Task Scheduler
1 parent 642fc57 commit b71f47f

File tree

6 files changed

+560
-4
lines changed

6 files changed

+560
-4
lines changed

include/tvm/meta_schedule/task_scheduler.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ namespace meta_schedule {
6767
*/
6868
class TaskSchedulerNode : public runtime::Object {
6969
public:
70+
/*! \brief The function type of the objective function. */
71+
using FObjectiveFunc = TypedPackedFunc<FloatImm(Array<FloatImm>)>;
72+
/*! \brief The function type of the tag genration function. */
73+
using FTagGenerationFunc = TypedPackedFunc<String(const IRModule&)>;
74+
7075
/*! \brief The tasks to be tuned */
7176
Array<TuneContext> tasks;
7277
/*! \brief The builder of the scheduler. */
@@ -288,6 +293,36 @@ class TaskScheduler : public runtime::ObjectRef {
288293
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
289294
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
290295
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
296+
/*!
297+
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
298+
* \param tasks The tasks to be tuned.
299+
* \param builder The builder of the scheduler.
300+
* \param runner The runner of the scheduler.
301+
* \param database The database of the scheduler.
302+
* \param alpha The parameter alpha to control gradient computation.
303+
* \param beta The parameter beta to control gradient computation.
304+
* \param backward_window_size The parameter to control backward window size.
305+
* \param seed The random seed.
306+
* \param task_weights The weights of each task.
307+
* \param objective_func_name The name of objective function for gradient optimization.
308+
* \param tag_generation_func_name The name of function to generate similarity tag for workloads.
309+
* \param cost_model The cost model of the scheduler.
310+
* \param measure_callbacks The measure callbacks of the scheduler.
311+
* \return The task scheduler created.
312+
*/
313+
TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks, //
314+
Builder builder, //
315+
Runner runner, //
316+
Database database, //
317+
double alpha, //
318+
double beta, //
319+
int backward_window_size, //
320+
support::LinearCongruentialEngine::TRandState seed, //
321+
Array<FloatImm> task_weights, //
322+
String objective_func_name, //
323+
String tag_generation_func_name, //
324+
Optional<CostModel> cost_model, //
325+
Optional<Array<MeasureCallback>> measure_callbacks);
291326
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
292327
};
293328

python/tvm/meta_schedule/task_scheduler/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
"""
2323
from .task_scheduler import TaskScheduler, PyTaskScheduler
2424
from .round_robin import RoundRobin
25+
from .gradient_based import GradientBased
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Gradient Based Task Scheduler"""
18+
import math
19+
20+
from typing import TYPE_CHECKING, List, Optional
21+
from tvm._ffi import register_object
22+
from tvm._ffi.registry import register_func
23+
24+
from tvm.ir import IRModule
25+
from ..measure_callback import MeasureCallback
26+
from ..builder import Builder
27+
from ..runner import Runner
28+
from ..database import Database
29+
from ..cost_model import CostModel
30+
from .task_scheduler import TaskScheduler
31+
32+
from .. import _ffi_api
33+
34+
if TYPE_CHECKING:
35+
from ..tune_context import TuneContext
36+
37+
38+
@register_func("meta_schedule.task_scheduler.derive_similarity_tag")
39+
def derive_similarity_tag(mod: IRModule) -> str:
40+
"""Get the tags for smilarity group creation
41+
42+
Parameters
43+
----------
44+
mod : IRModule
45+
The input workload.
46+
47+
Return
48+
------
49+
tag : str
50+
The generated similarity tag.
51+
"""
52+
ret = ""
53+
for var in mod.get_global_vars():
54+
if "meta_scheduler_task_scheduler_tag" in mod[var].attrs:
55+
ret += mod[var].attrs.meta_scheduler_task_scheduler_tag + "_"
56+
if ret:
57+
flop_count = _ffi_api.TaskSchedulerFlopCount(mod) # type: ignore # pylint: disable=no-member
58+
ret += "%d" % int(math.log(flop_count + 1, 1.618))
59+
return ret
60+
61+
62+
@register_object("meta_schedule.GradientBased")
63+
class GradientBased(TaskScheduler):
64+
"""Gradient Based Task Scheduler"""
65+
66+
def __init__(
67+
self,
68+
tasks: List["TuneContext"],
69+
builder: Builder,
70+
runner: Runner,
71+
database: Database,
72+
*,
73+
alpha: float = 0.2,
74+
beta: float = 2.0,
75+
backward_window_size: int = 3,
76+
seed: int = -1,
77+
task_weights: List[float] = None,
78+
objective_func_name: str = "meta_schedule.task_scheduler.objective_func",
79+
tag_generation_func_name: str = "meta_schedule.task_scheduler.derive_similarity_tag",
80+
cost_model: Optional[CostModel] = None,
81+
measure_callbacks: Optional[List[MeasureCallback]] = None,
82+
) -> None:
83+
"""Constructor.
84+
85+
Parameters
86+
----------
87+
tasks : List[TuneContext]
88+
List of tasks to schedule.
89+
builder : Builder
90+
The builder.
91+
runner : Runner
92+
The runner.
93+
database : Database
94+
The database.
95+
alpha : float, default 0.2.
96+
The parameter alpha to control gradient computation.
97+
beta : float, default 2.0.
98+
The parameter beta to control gradient computation.
99+
backward_window_size : int, default 3.
100+
The parameter to control backward window size.
101+
seed : int, default -1, meaning random seed.
102+
The random seed.
103+
task_weights : Optional[List[float]], default None, meaning equal weight.
104+
The weights of each task.
105+
objective_func_name : str, default "meta_schedule.task_scheduler.objective_func"
106+
The name of objective function for gradient optimization.
107+
tag_generation_func_name : str,
108+
default "meta_schedule.task_scheduler.derive_similarity_tag"
109+
The name of function to generate similarity tag for workloads.
110+
cost_model : CostModel, default None.
111+
The cost model of the scheduler.
112+
measure_callbacks : Optional[List[MeasureCallback]], default None.
113+
The list of measure callbacks of the scheduler.
114+
"""
115+
116+
@register_func("meta_schedule.task_scheduler.objective_func")
117+
def weighted_sum(latency: List[float]) -> float: # pylint: disable= unused-variable,
118+
"""Get the weighted sum as objective function
119+
120+
Parameters
121+
----------
122+
latency : List[float]
123+
The current latency of each workload.
124+
125+
Returns
126+
-------
127+
result : float
128+
The computed objective function value.
129+
"""
130+
return sum([latency[i] * w for i, w in enumerate(self.task_weights)])
131+
132+
if task_weights is None:
133+
task_weights = [1.0 for _ in tasks]
134+
self.task_weights = task_weights
135+
136+
assert len(task_weights) == len(
137+
tasks
138+
), "The given task weights should be same length as tasks."
139+
140+
self.__init_handle_by_constructor__(
141+
_ffi_api.TaskSchedulerGradientBased, # type: ignore # pylint: disable=no-member
142+
tasks,
143+
builder,
144+
runner,
145+
database,
146+
alpha,
147+
beta,
148+
backward_window_size,
149+
seed,
150+
task_weights,
151+
objective_func_name,
152+
tag_generation_func_name,
153+
cost_model,
154+
measure_callbacks,
155+
)

python/tvm/meta_schedule/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class _PyRunner(meta_schedule.Runner):
5353
def __init__(self, f_run: Callable = None):
5454
self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run)
5555
56-
class PyRunner():
56+
class PyRunner:
5757
_tvm_metadata = {
5858
"cls": _PyRunner,
5959
"methods": ["run"]

0 commit comments

Comments
 (0)