diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h new file mode 100644 index 0000000000..e1a4a02265 --- /dev/null +++ b/include/tvm/meta_schedule/measure_callback.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ +#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +class TaskScheduler; + +/*! \brief Rules to apply after measure results is available. */ +class MeasureCallbackNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MeasureCallbackNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Apply a measure callback rule with given arguments. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. + */ + virtual bool Apply(const TaskScheduler& task_scheduler, // + const Array tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) = 0; + + static constexpr const char* _type_key = "meta_schedule.MeasureCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); +}; + +/*! \brief The measure callback with customized methods on the python-side. */ +class PyMeasureCallbackNode : public MeasureCallbackNode { + public: + /*! + * \brief Apply a measure callback to the given schedule. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. + */ + using FApply = + runtime::TypedPackedFunc tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results)>; + /*! + * \brief Get the measure callback function as string with name. + * \return The string of the measure callback function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_apply` is not visited + // `f_as_string` is not visited + } + + bool Apply(const TaskScheduler& task_scheduler, // + const Array tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) final { + return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results); + } + + static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to MeasureCallbackNode + * \sa MeasureCallbackNode + */ +class MeasureCallback : public runtime::ObjectRef { + public: + /*! + * \brief Create a measure callback with customized methods on the python-side. + * \param f_apply The packed function of `Apply`. + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 0ea047397a..9c592b2794 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -73,6 +74,8 @@ class TaskSchedulerNode : public runtime::Object { Runner runner{nullptr}; /*! \brief The database of the scheduler. */ Database database{nullptr}; + /*! \brief The list of measure callbacks of the scheduler. */ + Array measure_callbacks; /*! \brief The default desctructor. */ virtual ~TaskSchedulerNode() = default; @@ -82,6 +85,7 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("builder", &builder); v->Visit("runner", &runner); v->Visit("database", &database); + v->Visit("measure_callbacks", &measure_callbacks); } /*! \brief Auto-tuning. */ @@ -222,8 +226,14 @@ class TaskScheduler : public runtime::ObjectRef { TVM_DLL static TaskScheduler RoundRobin(Array tasks, // Builder builder, // Runner runner, // - Database database); + Database database, // + Array measure_callbacks); TVM_DLL static TaskScheduler PyTaskScheduler( + Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Array measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // diff --git a/python/tvm/meta_schedule/measure_callback/__init__.py b/python/tvm/meta_schedule/measure_callback/__init__.py new file mode 100644 index 0000000000..f455c1f4c7 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.measure_callback package. +""" +from .measure_callback import MeasureCallback, PyMeasureCallback diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py new file mode 100644 index 0000000000..8d487905ea --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule MeasureCallback.""" + +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.builder import BuilderResult +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.utils import _get_hex_address + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + from ..task_scheduler import TaskScheduler + + +@register_object("meta_schedule.MeasureCallback") +class MeasureCallback(Object): + """Rules to apply after measure results is available.""" + + def apply( + self, + task_scheduler: "TaskScheduler", + tasks: List["TuneContext"], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + """Apply a measure callback to the given schedule. + + Parameters + ---------- + task_scheduler: TaskScheduler + The task scheduler. + tasks: List[TuneContext] + The list of tune context to process. + measure_candidats: List[MeasureCandidate] + The measure candidates. + builds: List[BuilderResult] + The builder results by building the measure candidates. + results: List[RunnerResult] + The runner results by running the built measure candidates. + + Returns + ------- + result : bool + Whether the measure callback was successfully applied. + """ + return _ffi_api.MeasureCallbackApply( + self, task_scheduler, tasks, measure_candidates, builds, results + ) + + +@register_object("meta_schedule.PyMeasureCallback") +class PyMeasureCallback(MeasureCallback): + """An abstract MeasureCallback with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_apply( + task_scheduler: "TaskScheduler", + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + return self.apply(task_scheduler, tasks, measure_candidates, builds, results) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyMeasureCallback({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 4e675e4ea7..e33431e15c 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -37,7 +37,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: Parameters ---------- tune_context : TuneContext - The tuning context for initializing the design space generator. + The tuning context for initializing the mutator. """ _ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member self, tune_context diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index c998ffc028..930a04f471 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -46,7 +46,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: Parameters ---------- tune_context : TuneContext - The tuning context for initializing the design space generator. + The tuning context for initializing the post processing. """ _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member self, tune_context diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 1d90622c28..1e10fe4183 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -41,7 +41,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: Parameters ---------- tune_context : TuneContext - The tuning context for initializing the design space generator. + The tuning context for initializing the schedule rule. """ _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member self, tune_context diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index f6fd50c15a..e306c307bc 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -20,6 +20,6 @@ to generate measure candidates. """ -from .search_strategy import SearchStrategy, PySearchStrategy +from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate from .replay_trace import ReplayTrace from .replay_func import ReplayFunc diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 391011b4f5..ab2d11ea7b 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -19,6 +19,7 @@ from typing import List, TYPE_CHECKING from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner @@ -41,6 +42,7 @@ def __init__( builder: Builder, runner: Runner, database: Database, + measure_callbacks: List[MeasureCallback] = [], ) -> None: """Constructor. @@ -54,6 +56,8 @@ def __init__( The runner. database : Database The database. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. """ self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member @@ -61,4 +65,5 @@ def __init__( builder, runner, database, + measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 531203f375..1e4042549e 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -15,15 +15,43 @@ # specific language governing permissions and limitations # under the License. """Auto-tuning Task Scheduler""" + +from typing import List + from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from tvm.runtime import Object +from ..runner import Runner +from ..builder import Builder +from ..database import Database +from ..tune_context import TuneContext from .. import _ffi_api @register_object("meta_schedule.TaskScheduler") class TaskScheduler(Object): - """The abstract task scheduler interface.""" + """The abstract task scheduler interface. + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. + """ + + tasks: List[TuneContext] + builder: Builder + runner: Runner + database: Database + measure_callbacks: List[MeasureCallback] def tune(self) -> None: """Auto-tuning.""" @@ -89,8 +117,29 @@ def _next_task_id(self) -> int: class PyTaskScheduler(TaskScheduler): """An abstract task scheduler with customized methods on the python-side.""" - def __init__(self): - """Constructor.""" + def __init__( + self, + tasks: List[TuneContext], + builder: Builder, + runner: Runner, + database: Database, + measure_callbacks: List[MeasureCallback] = [], + ): + """Constructor. + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. + """ def f_tune() -> None: self.tune() @@ -112,6 +161,11 @@ def f_next_task_id() -> int: self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerPyTaskScheduler, # pylint: disable=no-member + tasks, + builder, + runner, + database, + measure_callbacks, f_tune, f_initialize_task, f_set_task_stopped, diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc new file mode 100644 index 0000000000..8caee9447d --- /dev/null +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return MeasureCallback(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr); + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") + .set_body_method(&MeasureCallbackNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") + .set_body_typed(MeasureCallback::PyMeasureCallback); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index a529f2354d..2bd7cf4bcf 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -52,13 +52,17 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array tasks, Builder builder, Runner runner, - Database database) { +TaskScheduler TaskScheduler::RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Array measure_callbacks) { ObjectPtr n = make_object(); n->tasks = tasks; n->builder = builder; n->runner = runner; n->database = database; + n->measure_callbacks = measure_callbacks; n->task_id = -1; return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index fe5156633a..3bd2ab9dde 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -207,6 +207,11 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) { } TaskScheduler TaskScheduler::PyTaskScheduler( + Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Array measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // @@ -214,6 +219,11 @@ TaskScheduler TaskScheduler::PyTaskScheduler( PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // PyTaskSchedulerNode::FNextTaskId f_next_task_id) { ObjectPtr n = make_object(); + n->tasks = tasks; + n->builder = builder; + n->runner = runner; + n->database = database; + n->measure_callbacks = measure_callbacks; n->f_tune = f_tune; n->f_initialize_task = f_initialize_task; n->f_set_task_stopped = f_set_task_stopped; diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index d94f740a42..be76d3e8db 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py new file mode 100644 index 0000000000..9b7df907b0 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import re +from typing import List + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.meta_schedule.runner.runner import Runner +from tvm.meta_schedule.task_scheduler.task_scheduler import TaskScheduler +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T + +from tvm.meta_schedule.measure_callback import PyMeasureCallback +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.builder import BuilderResult +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.utils import _get_hex_address + +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_measure_callback(): + class FancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + assert len(measure_candidates) == 1 + assert_structural_equal(measure_candidates[0].sch.mod, Matmul) + assert ( + len(builds) == 1 + and builds[0].error_msg is None + and builds[0].artifact_path == "test_build" + ) + assert ( + len(results) == 1 and results[0].error_msg is None and len(results[0].run_secs) == 2 + ) + return True + + measure_callback = FancyMeasureCallback() + assert measure_callback.apply( + TaskScheduler(), + [], + [MeasureCandidate(Schedule(Matmul), None)], + [BuilderResult("test_build", None)], + [RunnerResult([1.0, 2.1], None)], + ) + + +def test_meta_schedule_measure_callback_fail(): + class FailingMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + return False + + measure_callback = FailingMeasureCallback() + assert not measure_callback.apply( + TaskScheduler(), + [], + [MeasureCandidate(None, None)], + [BuilderResult(None, None)], + [RunnerResult(None, None)], + ) + + +def test_meta_schedule_measure_callback_as_string(): + class NotSoFancyMeasureCallback(PyMeasureCallback): + def __str__(self) -> str: + return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})" + + measure_callback = NotSoFancyMeasureCallback() + pattern = re.compile(r"NotSoFancyMeasureCallback\(0x[a-f|0-9]*\)") + assert pattern.match(str(measure_callback)) + + +if __name__ == "__main__": + test_meta_schedule_measure_callback() + test_meta_schedule_measure_callback_fail() + test_meta_schedule_measure_callback_as_string() diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py index c7c981d71b..fc7926dc4a 100644 --- a/tests/python/unittest/test_meta_schedule_postproc.py +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -90,12 +90,12 @@ def apply(self, sch: Schedule) -> bool: def test_meta_schedule_postproc_as_string(): - class NotSoFancyPostProc(PyPostproc): + class NotSoFancyPostproc(PyPostproc): def __str__(self) -> str: - return f"NotSoFancyPostProc({_get_hex_address(self.handle)})" + return f"NotSoFancyPostproc({_get_hex_address(self.handle)})" - postproc = NotSoFancyPostProc() - pattern = re.compile(r"NotSoFancyPostProc\(0x[a-f|0-9]*\)") + postproc = NotSoFancyPostproc() + pattern = re.compile(r"NotSoFancyPostproc\(0x[a-f|0-9]*\)") assert pattern.match(str(postproc))