1717"""Meta Schedule tuning context."""
1818
1919import logging
20- from typing import Optional , List , Dict , TYPE_CHECKING
20+ from typing import TYPE_CHECKING , Dict , List , Optional
2121
2222from tvm import IRModule
2323from tvm ._ffi import register_object
2424from tvm .meta_schedule .utils import cpu_count , make_logging_func
2525from tvm .runtime import Object
2626from tvm .target import Target
27- from tvm .tir import PrimFunc
27+ from tvm .tir import PrimFunc , Schedule
2828
2929from . import _ffi_api
3030
3131if TYPE_CHECKING :
32- from .space_generator import SpaceGenerator
33- from .search_strategy import SearchStrategy
34- from .schedule_rule import ScheduleRule
35- from .postproc import Postproc
32+ from .cost_model import CostModel
33+ from .database import Database
3634 from .mutator import Mutator
35+ from .postproc import Postproc
36+ from .runner import Runner , RunnerResult
37+ from .schedule_rule import ScheduleRule
38+ from .search_strategy import MeasureCandidate , SearchStrategy
39+ from .space_generator import SpaceGenerator
3740
3841
3942@register_object ("meta_schedule.TuneContext" )
@@ -114,7 +117,6 @@ def __init__(
114117 self .logger = logging .getLogger (__name__ )
115118 else :
116119 self .logger = None
117-
118120 self .__init_handle_by_constructor__ (
119121 _ffi_api .TuneContext , # type: ignore # pylint: disable=no-member
120122 mod ,
@@ -132,5 +134,108 @@ def __init__(
132134
133135 def initialize (self ):
134136 """Initialize the tuning context"""
135-
136137 _ffi_api .TuneContextInitialize (self ) # type: ignore # pylint: disable=no-member
138+
139+ def generate_design_space (self , mod : IRModule ) -> List [Schedule ]:
140+ """Generate design spaces given a module.
141+
142+ Delegated to self.space_generator.generate_design_space.
143+
144+ Parameters
145+ ----------
146+ mod : IRModule
147+ The module used for design space generation.
148+
149+ Returns
150+ -------
151+ design_spaces : List[Schedule]
152+ The generated design spaces, i.e., schedules.
153+ """
154+ if self .space_generator is None :
155+ raise ValueError (
156+ "space_generator is not provided."
157+ "Please construct TuneContext with space_generator"
158+ )
159+ return self .space_generator .generate_design_space (mod )
160+
161+ def pre_tuning (
162+ self ,
163+ design_spaces : List [Schedule ],
164+ database : Optional ["Database" ] = None ,
165+ cost_model : Optional ["CostModel" ] = None ,
166+ ) -> None :
167+ """A method to be called for SearchStrategy to do necessary preparation before tuning.
168+
169+ Delegated to self.search_strategy.pre_tuning.
170+
171+ Parameters
172+ ----------
173+ design_spaces : List[Schedule]
174+ The design spaces used during tuning process.
175+ database : Optional[Database] = None
176+ The database used during tuning process.
177+ cost_model : Optional[CostModel] = None
178+ The cost model used during tuning process.
179+ """
180+ if self .search_strategy is None :
181+ raise ValueError (
182+ "search_strategy is not provided."
183+ "Please construct TuneContext with search_strategy"
184+ )
185+ return self .search_strategy .pre_tuning (design_spaces , database , cost_model )
186+
187+ def post_tuning (self ) -> None :
188+ """A method to be called for SearchStrategy to do necessary cleanup after tuning.
189+
190+ Delegated to self.search_strategy.post_tuning.
191+ """
192+ if self .search_strategy is None :
193+ raise ValueError (
194+ "search_strategy is not provided."
195+ "Please construct TuneContext with search_strategy"
196+ )
197+ _ffi_api .SearchStrategyPostTuning (self ) # type: ignore # pylint: disable=no-member
198+
199+ def generate_measure_candidates (self ) -> Optional [List ["MeasureCandidate" ]]:
200+ """Generate a batch of measure candidates from design spaces for measurement.
201+
202+ Delegated to self.search_strategy.generate_measure_candidates.
203+
204+ Returns
205+ -------
206+ measure_candidates : Optional[List[IRModule]]
207+ The measure candidates generated, None if search is finished.
208+ """
209+ if self .search_strategy is None :
210+ raise ValueError (
211+ "search_strategy is not provided."
212+ "Please construct TuneContext with search_strategy"
213+ )
214+ return _ffi_api .SearchStrategyGenerateMeasureCandidates (self ) # type: ignore # pylint: disable=no-member
215+
216+ def notify_runner_results (
217+ self ,
218+ measure_candidates : List ["MeasureCandidate" ],
219+ results : List ["RunnerResult" ],
220+ ) -> None :
221+ """Update the state in SearchStrategy with profiling results.
222+
223+ Delegated to self.search_strategy.notify_runner_results.
224+
225+ Parameters
226+ ----------
227+ measure_candidates : List[MeasureCandidate]
228+ The measure candidates for update.
229+ results : List[RunnerResult]
230+ The profiling results from the runner.
231+ """
232+ if self .search_strategy is None :
233+ raise ValueError (
234+ "search_strategy is not provided."
235+ "Please construct TuneContext with search_strategy"
236+ )
237+ _ffi_api .SearchStrategyNotifyRunnerResults ( # type: ignore # pylint: disable=no-member
238+ self ,
239+ measure_candidates ,
240+ results ,
241+ )
0 commit comments