Skip to content

Commit 540fac7

Browse files
committed
Refactor
1 parent b71f47f commit 540fac7

32 files changed

+708
-695
lines changed

include/tvm/meta_schedule/search_strategy.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,21 +252,21 @@ class SearchStrategy : public runtime::ObjectRef {
252252
/*!
253253
* \brief Constructor of replay trace search strategy.
254254
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
255-
* \param num_trials_total The total number of trials for trace replaying.
255+
* \param max_trials_per_task The total number of trials for trace replaying.
256256
*/
257-
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
257+
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task);
258258

259259
/*!
260260
* \brief Constructor of replay func search strategy.
261261
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
262-
* \param num_trials_total The total number of trials for func replaying.
262+
* \param max_trials_per_task The total number of trials for func replaying.
263263
*/
264-
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
264+
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task);
265265

266266
/*!
267267
* \brief Constructor of evolutionary search strategy.
268268
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
269-
* \param num_trials_total The total number of trials for evolutionary search.
269+
* \param max_trials_per_task The total number of trials for evolutionary search.
270270
* \param population_size The initial sample population.
271271
* \param init_measured_ratio The ratio of measures samples in initial population.
272272
* \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
@@ -276,7 +276,7 @@ class SearchStrategy : public runtime::ObjectRef {
276276
* \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
277277
*/
278278
TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
279-
int num_trials_total, //
279+
int max_trials_per_task, //
280280
int population_size, //
281281
double init_measured_ratio, //
282282
int init_min_unmeasured, //

include/tvm/meta_schedule/task_scheduler.h

Lines changed: 52 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,6 @@ namespace meta_schedule {
6767
*/
6868
class TaskSchedulerNode : public runtime::Object {
6969
public:
70-
/*! \brief The function type of the objective function. */
71-
using FObjectiveFunc = TypedPackedFunc<FloatImm(Array<FloatImm>)>;
72-
/*! \brief The function type of the tag genration function. */
73-
using FTagGenerationFunc = TypedPackedFunc<String(const IRModule&)>;
74-
7570
/*! \brief The tasks to be tuned */
7671
Array<TuneContext> tasks;
7772
/*! \brief The builder of the scheduler. */
@@ -80,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object {
8075
Runner runner{nullptr};
8176
/*! \brief The database of the scheduler. */
8277
Database database{nullptr};
78+
/*! \brief The maximum number of trials allowed. */
79+
int max_trials;
8380
/*! \brief The cost model of the scheduler. */
8481
Optional<CostModel> cost_model;
8582
/*! \brief The list of measure callbacks of the scheduler. */
8683
Array<MeasureCallback> measure_callbacks;
84+
/*! \brief The number of trials already conducted. */
85+
int num_trials_already;
8786

8887
/*! \brief The default destructor. */
8988
virtual ~TaskSchedulerNode() = default;
@@ -93,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object {
9392
v->Visit("builder", &builder);
9493
v->Visit("runner", &runner);
9594
v->Visit("database", &database);
95+
v->Visit("max_trials", &max_trials);
9696
v->Visit("cost_model", &cost_model);
9797
v->Visit("measure_callbacks", &measure_callbacks);
98+
v->Visit("num_trials_already", &num_trials_already);
9899
}
99100

100101
/*! \brief Auto-tuning. */
@@ -107,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object {
107108
virtual void InitializeTask(int task_id);
108109

109110
/*!
110-
* \brief Set specific task to be stopped.
111-
* \param task_id The task id to be stopped.
112-
*/
113-
virtual void SetTaskStopped(int task_id);
114-
115-
/*!
116-
* \brief Check whether the task is running.
111+
* \brief Touch the task and update its status
117112
* \param task_id The task id to be checked.
118-
* \return Whether the task is running.
119113
*/
120-
virtual bool IsTaskRunning(int task_id);
114+
virtual void TouchTask(int task_id);
121115

122116
/*!
123117
* \brief Wait until the task is finished.
124118
* \param task_id The task id to be joined.
125119
*/
126-
virtual void JoinRunningTask(int task_id);
120+
virtual Array<RunnerResult> JoinRunningTask(int task_id);
127121

128122
/*!
129123
* \brief Fetch the next task id.
@@ -147,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
147141
using FInitializeTask = runtime::TypedPackedFunc<void(int)>;
148142

149143
/*!
150-
* \brief The function type of `SetTaskStopped` method.
151-
* \param task_id The task id to be stopped.
152-
*/
153-
using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;
154-
155-
/*!
156-
* \brief The function type of `IsTaskRunning` method.
144+
* \brief The function type of `TouchTask` method.
157145
* \param task_id The task id to be checked.
158146
* \return Whether the task is running.
159147
*/
160-
using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;
148+
using FTouchTask = runtime::TypedPackedFunc<void(int)>;
161149

162150
/*!
163151
* \brief The function type of `JoinRunningTask` method.
164152
* \param task_id The task id to be joined.
165153
*/
166-
using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;
154+
using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>;
167155

168156
/*!
169157
* \brief The function type of `NextTaskId` method.
@@ -175,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
175163
FTune f_tune;
176164
/*! \brief The packed function to the `InitializeTask` function. */
177165
FInitializeTask f_initialize_task;
178-
/*! \brief The packed function to the `SetTaskStopped` function. */
179-
FSetTaskStopped f_set_task_stopped;
180-
/*! \brief The packed function to the `IsTaskRunning` function. */
181-
FIsTaskRunning f_is_task_running;
166+
/*! \brief The packed function to the `TouchTask` function. */
167+
FTouchTask f_touch_task;
182168
/*! \brief The packed function to the `JoinRunningTask` function. */
183169
FJoinRunningTask f_join_running_task;
184170
/*! \brief The packed function to the `NextTaskId` function. */
@@ -187,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
187173
void VisitAttrs(tvm::AttrVisitor* v) {
188174
// `f_tune` is not visited
189175
// `f_initialize_task` is not visited
190-
// `f_set_task_stopped` is not visited
191-
// `f_is_task_running` is not visited
176+
// `f_touch_task` is not visited
192177
// `f_join_running_task` is not visited
193178
// `f_next_task_id` is not visited
194179
}
@@ -209,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
209194
}
210195
}
211196

212-
void SetTaskStopped(int task_id) final {
213-
if (f_set_task_stopped == nullptr) {
214-
TaskSchedulerNode::SetTaskStopped(task_id);
215-
} else {
216-
f_set_task_stopped(task_id);
217-
}
218-
}
219-
220-
bool IsTaskRunning(int task_id) final {
221-
if (f_is_task_running == nullptr) {
222-
return TaskSchedulerNode::IsTaskRunning(task_id);
197+
void TouchTask(int task_id) final {
198+
if (f_touch_task == nullptr) {
199+
return TaskSchedulerNode::TouchTask(task_id);
223200
} else {
224-
return f_is_task_running(task_id);
201+
return f_touch_task(task_id);
225202
}
226203
}
227204

228-
void JoinRunningTask(int task_id) final {
205+
Array<RunnerResult> JoinRunningTask(int task_id) final {
229206
if (f_join_running_task == nullptr) {
230207
return TaskSchedulerNode::JoinRunningTask(task_id);
231208
} else {
@@ -254,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef {
254231
* \param builder The builder of the scheduler.
255232
* \param runner The runner of the scheduler.
256233
* \param database The database of the scheduler.
234+
* \param max_trials The maximum number of trials.
257235
* \param cost_model The cost model of the scheduler.
258236
* \param measure_callbacks The measure callbacks of the scheduler.
259237
* \return The task scheduler created.
@@ -262,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef {
262240
Builder builder, //
263241
Runner runner, //
264242
Database database, //
243+
int max_trials, //
265244
Optional<CostModel> cost_model, //
266245
Optional<Array<MeasureCallback>> measure_callbacks);
246+
/*!
247+
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
248+
* \param tasks The tasks to be tuned.
249+
* \param task_weights The weights of each task.
250+
* \param builder The builder of the scheduler.
251+
* \param runner The runner of the scheduler.
252+
* \param database The database of the scheduler.
253+
* \param max_trials The maximum number of trials.
254+
* \param cost_model The cost model of the scheduler.
255+
* \param measure_callbacks The measure callbacks of the scheduler.
256+
* \param alpha The parameter alpha to control gradient computation.
257+
* \param window_size The parameter to control backward window size.
258+
* \param seed The random seed.
259+
* \return The task scheduler created.
260+
*/
261+
TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks,
262+
Array<FloatImm> task_weights, //
263+
Builder builder, //
264+
Runner runner, //
265+
Database database, //
266+
int max_trials, //
267+
Optional<CostModel> cost_model, //
268+
Optional<Array<MeasureCallback>> measure_callbacks, //
269+
double alpha, //
270+
int window_size, //
271+
support::LinearCongruentialEngine::TRandState seed);
267272
/*!
268273
* \brief Create a task scheduler with customized methods on the python-side.
269274
* \param tasks The tasks to be tuned.
270275
* \param builder The builder of the scheduler.
271276
* \param runner The runner of the scheduler.
272277
* \param database The database of the scheduler.
278+
* \param max_trials The maximum number of trials.
273279
* \param cost_model The cost model of the scheduler.
274280
* \param measure_callbacks The measure callbacks of the scheduler.
275281
* \param f_tune The packed function of `Tune`.
276282
* \param f_initialize_task The packed function of `InitializeTask`.
277-
* \param f_set_task_stopped The packed function of `SetTaskStopped`.
278-
* \param f_is_task_running The packed function of `IsTaskRunning`.
283+
* \param f_touch_task The packed function of `TouchTask`.
279284
* \param f_join_running_task The packed function of `JoinRunningTask`.
280285
* \param f_next_task_id The packed function of `NextTaskId`.
281286
* \return The task scheduler created.
@@ -285,44 +290,14 @@ class TaskScheduler : public runtime::ObjectRef {
285290
Builder builder, //
286291
Runner runner, //
287292
Database database, //
293+
int max_trials, //
288294
Optional<CostModel> cost_model, //
289295
Optional<Array<MeasureCallback>> measure_callbacks, //
290296
PyTaskSchedulerNode::FTune f_tune, //
291297
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
292-
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
293-
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
298+
PyTaskSchedulerNode::FTouchTask f_touch_task, //
294299
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
295300
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
296-
/*!
297-
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
298-
* \param tasks The tasks to be tuned.
299-
* \param builder The builder of the scheduler.
300-
* \param runner The runner of the scheduler.
301-
* \param database The database of the scheduler.
302-
* \param alpha The parameter alpha to control gradient computation.
303-
* \param beta The parameter beta to control gradient computation.
304-
* \param backward_window_size The parameter to control backward window size.
305-
* \param seed The random seed.
306-
* \param task_weights The weights of each task.
307-
* \param objective_func_name The name of objective function for gradient optimization.
308-
* \param tag_generation_func_name The name of function to generate similarity tag for workloads.
309-
* \param cost_model The cost model of the scheduler.
310-
* \param measure_callbacks The measure callbacks of the scheduler.
311-
* \return The task scheduler created.
312-
*/
313-
TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks, //
314-
Builder builder, //
315-
Runner runner, //
316-
Database database, //
317-
double alpha, //
318-
double beta, //
319-
int backward_window_size, //
320-
support::LinearCongruentialEngine::TRandState seed, //
321-
Array<FloatImm> task_weights, //
322-
String objective_func_name, //
323-
String tag_generation_func_name, //
324-
Optional<CostModel> cost_model, //
325-
Optional<Array<MeasureCallback>> measure_callbacks);
326301
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
327302
};
328303

include/tvm/meta_schedule/tune_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class TuneContextNode : public runtime::Object {
6262
/*! \brief The task scheduler that owns the tune context */
6363
const TaskSchedulerNode* task_scheduler;
6464
/*! \brief Whether the tuning task has been stopped or finished. */
65-
bool is_stopped;
65+
bool is_terminated;
6666
/*! \brief The measure candidates. */
6767
Optional<Array<MeasureCandidate>> measure_candidates;
6868
/*! \brief The building results. */
@@ -81,7 +81,7 @@ class TuneContextNode : public runtime::Object {
8181
v->Visit("task_name", &task_name);
8282
v->Visit("rand_state", &rand_state);
8383
v->Visit("num_threads", &num_threads);
84-
v->Visit("is_stopped", &is_stopped);
84+
v->Visit("is_terminated", &is_terminated);
8585
v->Visit("builder_results", &builder_results);
8686
v->Visit("runner_futures", &runner_futures);
8787
v->Visit("measure_candidates", &measure_candidates);

include/tvm/support/random_engine.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,15 @@ class LinearCongruentialEngine {
9999
* \brief Change the start random state of RNG with the seed of a new random state value.
100100
* \param rand_state The random state given in result_type.
101101
*/
102-
void Seed(TRandState rand_state = 1) {
103-
ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!";
104-
rand_state %= modulus; // Make sure the seed is within the range of modulus.
105-
if (rand_state == 0)
106-
rand_state = 1; // Avoid getting all 0 given the current parameter set.
107-
else if (rand_state < 0)
108-
rand_state += modulus; // Make sure the rand state is non-negative.
109-
ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null.
110-
*rand_state_ptr_ = rand_state; // Change pointed random state to given random state value.
102+
void Seed(TRandState rand_state) {
103+
if (rand_state == -1) {
104+
rand_state = DeviceRandom();
105+
} else if (rand_state == 0) {
106+
rand_state = 1;
107+
}
108+
ICHECK(rand_state >= 0) << "The random state should be nonnegative";
109+
ICHECK(rand_state_ptr_ != nullptr);
110+
*rand_state_ptr_ = rand_state % modulus;
111111
}
112112

113113
/*!

include/tvm/tir/schedule/schedule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class ScheduleNode : public runtime::Object {
128128
* \brief Seed the randomness
129129
* \param seed The new random seed, -1 if use device random, otherwise non-negative
130130
*/
131-
virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
131+
virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
132132
/*! \brief Fork the random state */
133133
virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;
134134

python/tvm/meta_schedule/search_strategy/evolutionary_search.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class EvolutionarySearch(SearchStrategy):
3434
----------
3535
num_trials_per_iter : int
3636
Number of trials per iteration.
37-
num_trials_total : int
37+
max_trials_per_task : int
3838
Total number of trials.
3939
population_size : int
4040
The initial population of traces from measured samples and randomly generated samples.
@@ -53,7 +53,7 @@ class EvolutionarySearch(SearchStrategy):
5353
"""
5454

5555
num_trials_per_iter: int
56-
num_trials_total: int
56+
max_trials_per_task: int
5757
population_size: int
5858
init_measured_ratio: int
5959
init_min_unmeasured: int
@@ -66,7 +66,7 @@ def __init__(
6666
self,
6767
*,
6868
num_trials_per_iter: int,
69-
num_trials_total: int,
69+
max_trials_per_task: int,
7070
population_size: int,
7171
init_measured_ratio: float,
7272
init_min_unmeasured: int,
@@ -79,7 +79,7 @@ def __init__(
7979
self.__init_handle_by_constructor__(
8080
_ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member
8181
num_trials_per_iter,
82-
num_trials_total,
82+
max_trials_per_task,
8383
population_size,
8484
init_measured_ratio,
8585
init_min_unmeasured,
@@ -94,7 +94,8 @@ class EvolutionarySearchConfig(NamedTuple):
9494
"""Configuration for EvolutionarySearch"""
9595

9696
num_trials_per_iter: int
97-
num_trials_total: int
97+
max_trials_per_task: int
98+
max_trials_global: int
9899
population_size: int = 2048
99100
init_measured_ratio: float = 0.2
100101
init_min_unmeasured: int = 50
@@ -106,7 +107,7 @@ class EvolutionarySearchConfig(NamedTuple):
106107
def create_strategy(self) -> EvolutionarySearch:
107108
return EvolutionarySearch(
108109
num_trials_per_iter=self.num_trials_per_iter,
109-
num_trials_total=self.num_trials_total,
110+
max_trials_per_task=self.max_trials_per_task,
110111
population_size=self.population_size,
111112
init_measured_ratio=self.init_measured_ratio,
112113
init_min_unmeasured=self.init_min_unmeasured,

0 commit comments

Comments
 (0)