@@ -67,11 +67,6 @@ namespace meta_schedule {
6767*/
6868class TaskSchedulerNode : public runtime ::Object {
6969 public:
70- /* ! \brief The function type of the objective function. */
71- using FObjectiveFunc = TypedPackedFunc<FloatImm(Array<FloatImm>)>;
72- /* ! \brief The function type of the tag genration function. */
73- using FTagGenerationFunc = TypedPackedFunc<String(const IRModule&)>;
74-
7570 /* ! \brief The tasks to be tuned */
7671 Array<TuneContext> tasks;
7772 /* ! \brief The builder of the scheduler. */
@@ -80,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object {
8075 Runner runner{nullptr };
8176 /* ! \brief The database of the scheduler. */
8277 Database database{nullptr };
78+ /* ! \brief The maximum number of trials allowed. */
79+ int max_trials;
8380 /* ! \brief The cost model of the scheduler. */
8481 Optional<CostModel> cost_model;
8582 /* ! \brief The list of measure callbacks of the scheduler. */
8683 Array<MeasureCallback> measure_callbacks;
84+ /* ! \brief The number of trials already conducted. */
85+ int num_trials_already;
8786
8887 /* ! \brief The default destructor. */
8988 virtual ~TaskSchedulerNode () = default ;
@@ -93,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object {
9392 v->Visit (" builder" , &builder);
9493 v->Visit (" runner" , &runner);
9594 v->Visit (" database" , &database);
95+ v->Visit (" max_trials" , &max_trials);
9696 v->Visit (" cost_model" , &cost_model);
9797 v->Visit (" measure_callbacks" , &measure_callbacks);
98+ v->Visit (" num_trials_already" , &num_trials_already);
9899 }
99100
100101 /* ! \brief Auto-tuning. */
@@ -107,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object {
107108 virtual void InitializeTask (int task_id);
108109
109110 /* !
110- * \brief Set specific task to be stopped.
111- * \param task_id The task id to be stopped.
112- */
113- virtual void SetTaskStopped (int task_id);
114-
115- /* !
116- * \brief Check whether the task is running.
111+ * \brief Touch the task and update its status
117112 * \param task_id The task id to be checked.
118- * \return Whether the task is running.
119113 */
120- virtual bool IsTaskRunning (int task_id);
114+ virtual void TouchTask (int task_id);
121115
122116 /* !
123117 * \brief Wait until the task is finished.
124118 * \param task_id The task id to be joined.
125119 */
126- virtual void JoinRunningTask (int task_id);
120+ virtual Array<RunnerResult> JoinRunningTask (int task_id);
127121
128122 /* !
129123 * \brief Fetch the next task id.
@@ -147,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
147141 using FInitializeTask = runtime::TypedPackedFunc<void (int )>;
148142
149143 /* !
150- * \brief The function type of `SetTaskStopped` method.
151- * \param task_id The task id to be stopped.
152- */
153- using FSetTaskStopped = runtime::TypedPackedFunc<void (int )>;
154-
155- /* !
156- * \brief The function type of `IsTaskRunning` method.
144+ * \brief The function type of `TouchTask` method.
157145 * \param task_id The task id to be checked.
158146 * \return Whether the task is running.
159147 */
160- using FIsTaskRunning = runtime::TypedPackedFunc<bool (int )>;
148+ using FTouchTask = runtime::TypedPackedFunc<void (int )>;
161149
162150 /* !
163151 * \brief The function type of `JoinRunningTask` method.
164152 * \param task_id The task id to be joined.
165153 */
166- using FJoinRunningTask = runtime::TypedPackedFunc<void (int )>;
154+ using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult> (int )>;
167155
168156 /* !
169157 * \brief The function type of `NextTaskId` method.
@@ -175,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
175163 FTune f_tune;
176164 /* ! \brief The packed function to the `InitializeTask` function. */
177165 FInitializeTask f_initialize_task;
178- /* ! \brief The packed function to the `SetTaskStopped` function. */
179- FSetTaskStopped f_set_task_stopped;
180- /* ! \brief The packed function to the `IsTaskRunning` function. */
181- FIsTaskRunning f_is_task_running;
166+ /* ! \brief The packed function to the `TouchTask` function. */
167+ FTouchTask f_touch_task;
182168 /* ! \brief The packed function to the `JoinRunningTask` function. */
183169 FJoinRunningTask f_join_running_task;
184170 /* ! \brief The packed function to the `NextTaskId` function. */
@@ -187,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
187173 void VisitAttrs (tvm::AttrVisitor* v) {
188174 // `f_tune` is not visited
189175 // `f_initialize_task` is not visited
190- // `f_set_task_stopped` is not visited
191- // `f_is_task_running` is not visited
176+ // `f_touch_task` is not visited
192177 // `f_join_running_task` is not visited
193178 // `f_next_task_id` is not visited
194179 }
@@ -209,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
209194 }
210195 }
211196
212- void SetTaskStopped (int task_id) final {
213- if (f_set_task_stopped == nullptr ) {
214- TaskSchedulerNode::SetTaskStopped (task_id);
215- } else {
216- f_set_task_stopped (task_id);
217- }
218- }
219-
220- bool IsTaskRunning (int task_id) final {
221- if (f_is_task_running == nullptr ) {
222- return TaskSchedulerNode::IsTaskRunning (task_id);
197+ void TouchTask (int task_id) final {
198+ if (f_touch_task == nullptr ) {
199+ return TaskSchedulerNode::TouchTask (task_id);
223200 } else {
224- return f_is_task_running (task_id);
201+ return f_touch_task (task_id);
225202 }
226203 }
227204
228- void JoinRunningTask (int task_id) final {
205+ Array<RunnerResult> JoinRunningTask (int task_id) final {
229206 if (f_join_running_task == nullptr ) {
230207 return TaskSchedulerNode::JoinRunningTask (task_id);
231208 } else {
@@ -254,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef {
254231 * \param builder The builder of the scheduler.
255232 * \param runner The runner of the scheduler.
256233 * \param database The database of the scheduler.
234+ * \param max_trials The maximum number of trials.
257235 * \param cost_model The cost model of the scheduler.
258236 * \param measure_callbacks The measure callbacks of the scheduler.
259237 * \return The task scheduler created.
@@ -262,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef {
262240 Builder builder, //
263241 Runner runner, //
264242 Database database, //
243+ int max_trials, //
265244 Optional<CostModel> cost_model, //
266245 Optional<Array<MeasureCallback>> measure_callbacks);
246+ /* !
247+ * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
248+ * \param tasks The tasks to be tuned.
249+ * \param task_weights The weights of each task.
250+ * \param builder The builder of the scheduler.
251+ * \param runner The runner of the scheduler.
252+ * \param database The database of the scheduler.
253+ * \param max_trials The maximum number of trials.
254+ * \param cost_model The cost model of the scheduler.
255+ * \param measure_callbacks The measure callbacks of the scheduler.
256+ * \param alpha The parameter alpha to control gradient computation.
257+ * \param window_size The parameter to control backward window size.
258+ * \param seed The random seed.
259+ * \return The task scheduler created.
260+ */
261+ TVM_DLL static TaskScheduler GradientBased (Array<TuneContext> tasks,
262+ Array<FloatImm> task_weights, //
263+ Builder builder, //
264+ Runner runner, //
265+ Database database, //
266+ int max_trials, //
267+ Optional<CostModel> cost_model, //
268+ Optional<Array<MeasureCallback>> measure_callbacks, //
269+ double alpha, //
270+ int window_size, //
271+ support::LinearCongruentialEngine::TRandState seed);
267272 /* !
268273 * \brief Create a task scheduler with customized methods on the python-side.
269274 * \param tasks The tasks to be tuned.
270275 * \param builder The builder of the scheduler.
271276 * \param runner The runner of the scheduler.
272277 * \param database The database of the scheduler.
278+ * \param max_trials The maximum number of trials.
273279 * \param cost_model The cost model of the scheduler.
274280 * \param measure_callbacks The measure callbacks of the scheduler.
275281 * \param f_tune The packed function of `Tune`.
276282 * \param f_initialize_task The packed function of `InitializeTask`.
277- * \param f_set_task_stopped The packed function of `SetTaskStopped`.
278- * \param f_is_task_running The packed function of `IsTaskRunning`.
283+ * \param f_touch_task The packed function of `TouchTask`.
279284 * \param f_join_running_task The packed function of `JoinRunningTask`.
280285 * \param f_next_task_id The packed function of `NextTaskId`.
281286 * \return The task scheduler created.
@@ -285,44 +290,14 @@ class TaskScheduler : public runtime::ObjectRef {
285290 Builder builder, //
286291 Runner runner, //
287292 Database database, //
293+ int max_trials, //
288294 Optional<CostModel> cost_model, //
289295 Optional<Array<MeasureCallback>> measure_callbacks, //
290296 PyTaskSchedulerNode::FTune f_tune, //
291297 PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
292- PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
293- PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
298+ PyTaskSchedulerNode::FTouchTask f_touch_task, //
294299 PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
295300 PyTaskSchedulerNode::FNextTaskId f_next_task_id);
296- /* !
297- * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
298- * \param tasks The tasks to be tuned.
299- * \param builder The builder of the scheduler.
300- * \param runner The runner of the scheduler.
301- * \param database The database of the scheduler.
302- * \param alpha The parameter alpha to control gradient computation.
303- * \param beta The parameter beta to control gradient computation.
304- * \param backward_window_size The parameter to control backward window size.
305- * \param seed The random seed.
306- * \param task_weights The weights of each task.
307- * \param objective_func_name The name of objective function for gradient optimization.
308- * \param tag_generation_func_name The name of function to generate similarity tag for workloads.
309- * \param cost_model The cost model of the scheduler.
310- * \param measure_callbacks The measure callbacks of the scheduler.
311- * \return The task scheduler created.
312- */
313- TVM_DLL static TaskScheduler GradientBased (Array<TuneContext> tasks, //
314- Builder builder, //
315- Runner runner, //
316- Database database, //
317- double alpha, //
318- double beta, //
319- int backward_window_size, //
320- support::LinearCongruentialEngine::TRandState seed, //
321- Array<FloatImm> task_weights, //
322- String objective_func_name, //
323- String tag_generation_func_name, //
324- Optional<CostModel> cost_model, //
325- Optional<Array<MeasureCallback>> measure_callbacks);
326301 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS (TaskScheduler, ObjectRef, TaskSchedulerNode);
327302};
328303
0 commit comments