Skip to content

Commit 1cc1563

Browse files
committed
Delegate class members' methods
1 parent 1dbf77d commit 1cc1563

File tree

9 files changed

+129
-41
lines changed

9 files changed

+129
-41
lines changed

include/tvm/meta_schedule/search_strategy.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,10 @@ class SearchStrategyNode : public runtime::Object {
113113

114114
/*!
115115
* \brief Update the search strategy with measurement results.
116-
* \param context The tuning context.
117116
* \param measure_candidates The candidates to be measured.
118117
* \param results The measurement results from the runner.
119118
*/
120-
virtual void NotifyRunnerResults(const TuneContext& context,
121-
const Array<MeasureCandidate>& measure_candidates,
119+
virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
122120
const Array<RunnerResult>& results) = 0;
123121

124122
static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
@@ -150,8 +148,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
150148
* \brief The function type of `NotifyRunnerResults` method.
151149
* \param results The measurement results from the runner.
152150
*/
153-
using FNotifyRunnerResults = runtime::TypedPackedFunc<void(
154-
const TuneContext&, const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
151+
using FNotifyRunnerResults =
152+
runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
155153

156154
/*! \brief The packed function to the `InitializeWithTuneContext` method. */
157155
FInitializeWithTuneContext f_initialize_with_tune_context;
@@ -177,8 +175,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
177175
const Optional<CostModel>& cost_model) final;
178176
void PostTuning() final;
179177
Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
180-
void NotifyRunnerResults(const TuneContext& context,
181-
const Array<MeasureCandidate>& measure_candidates,
178+
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
182179
const Array<RunnerResult>& results);
183180

184181
static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";

python/tvm/meta_schedule/search_strategy/search_strategy.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,24 +129,20 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]:
129129

130130
def notify_runner_results(
131131
self,
132-
context: "TuneContext",
133132
measure_candidates: List[MeasureCandidate],
134133
results: List[RunnerResult],
135134
) -> None:
136135
"""Update the search strategy with profiling results.
137136
138137
Parameters
139138
----------
140-
context : TuneContext
141-
The tuning context for update.
142139
measure_candidates : List[MeasureCandidate]
143140
The measure candidates for update.
144141
results : List[RunnerResult]
145142
The profiling results from the runner.
146143
"""
147144
_ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member
148145
self,
149-
context,
150146
measure_candidates,
151147
results,
152148
)
@@ -236,16 +232,13 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]:
236232

237233
def notify_runner_results(
238234
self,
239-
context: "TuneContext",
240235
measure_candidates: List[MeasureCandidate],
241236
results: List[RunnerResult],
242237
) -> None:
243238
"""Update the search strategy with profiling results.
244239
245240
Parameters
246241
----------
247-
context : TuneContext
248-
The tuning context for update.
249242
measure_candidates : List[MeasureCandidate]
250243
The measure candidates for update.
251244
results : List[RunnerResult]

python/tvm/meta_schedule/tune_context.py

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,26 @@
1717
"""Meta Schedule tuning context."""
1818

1919
import logging
20-
from typing import Optional, List, Dict, TYPE_CHECKING
20+
from typing import TYPE_CHECKING, Dict, List, Optional
2121

2222
from tvm import IRModule
2323
from tvm._ffi import register_object
2424
from tvm.meta_schedule.utils import cpu_count, make_logging_func
2525
from tvm.runtime import Object
2626
from tvm.target import Target
27-
from tvm.tir import PrimFunc
27+
from tvm.tir import PrimFunc, Schedule
2828

2929
from . import _ffi_api
3030

3131
if TYPE_CHECKING:
32-
from .space_generator import SpaceGenerator
33-
from .search_strategy import SearchStrategy
34-
from .schedule_rule import ScheduleRule
35-
from .postproc import Postproc
32+
from .cost_model import CostModel
33+
from .database import Database
3634
from .mutator import Mutator
35+
from .postproc import Postproc
36+
from .runner import Runner, RunnerResult
37+
from .schedule_rule import ScheduleRule
38+
from .search_strategy import MeasureCandidate, SearchStrategy
39+
from .space_generator import SpaceGenerator
3740

3841

3942
@register_object("meta_schedule.TuneContext")
@@ -114,7 +117,6 @@ def __init__(
114117
self.logger = logging.getLogger(__name__)
115118
else:
116119
self.logger = None
117-
118120
self.__init_handle_by_constructor__(
119121
_ffi_api.TuneContext, # type: ignore # pylint: disable=no-member
120122
mod,
@@ -132,5 +134,108 @@ def __init__(
132134

133135
def initialize(self):
134136
"""Initialize the tuning context"""
135-
136137
_ffi_api.TuneContextInitialize(self) # type: ignore # pylint: disable=no-member
138+
139+
def generate_design_space(self, mod: IRModule) -> List[Schedule]:
140+
"""Generate design spaces given a module.
141+
142+
Delegated to self.space_generator.generate_design_space.
143+
144+
Parameters
145+
----------
146+
mod : IRModule
147+
The module used for design space generation.
148+
149+
Returns
150+
-------
151+
design_spaces : List[Schedule]
152+
The generated design spaces, i.e., schedules.
153+
"""
154+
if self.space_generator is None:
155+
raise ValueError(
156+
"space_generator is not provided."
157+
"Please construct TuneContext with space_generator"
158+
)
159+
return self.space_generator.generate_design_space(mod)
160+
161+
def pre_tuning(
162+
self,
163+
design_spaces: List[Schedule],
164+
database: Optional["Database"] = None,
165+
cost_model: Optional["CostModel"] = None,
166+
) -> None:
167+
"""A method to be called for SearchStrategy to do necessary preparation before tuning.
168+
169+
Delegated to self.search_strategy.pre_tuning.
170+
171+
Parameters
172+
----------
173+
design_spaces : List[Schedule]
174+
The design spaces used during tuning process.
175+
database : Optional[Database] = None
176+
The database used during tuning process.
177+
cost_model : Optional[CostModel] = None
178+
The cost model used during tuning process.
179+
"""
180+
if self.search_strategy is None:
181+
raise ValueError(
182+
"search_strategy is not provided."
183+
"Please construct TuneContext with search_strategy"
184+
)
185+
return self.search_strategy.pre_tuning(design_spaces, database, cost_model)
186+
187+
def post_tuning(self) -> None:
188+
"""A method to be called for SearchStrategy to do necessary cleanup after tuning.
189+
190+
Delegated to self.search_strategy.post_tuning.
191+
"""
192+
if self.search_strategy is None:
193+
raise ValueError(
194+
"search_strategy is not provided."
195+
"Please construct TuneContext with search_strategy"
196+
)
197+
_ffi_api.SearchStrategyPostTuning(self) # type: ignore # pylint: disable=no-member
198+
199+
def generate_measure_candidates(self) -> Optional[List["MeasureCandidate"]]:
200+
"""Generate a batch of measure candidates from design spaces for measurement.
201+
202+
Delegated to self.search_strategy.generate_measure_candidates.
203+
204+
Returns
205+
-------
206+
measure_candidates : Optional[List[IRModule]]
207+
The measure candidates generated, None if search is finished.
208+
"""
209+
if self.search_strategy is None:
210+
raise ValueError(
211+
"search_strategy is not provided."
212+
"Please construct TuneContext with search_strategy"
213+
)
214+
return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member
215+
216+
def notify_runner_results(
217+
self,
218+
measure_candidates: List["MeasureCandidate"],
219+
results: List["RunnerResult"],
220+
) -> None:
221+
"""Update the state in SearchStrategy with profiling results.
222+
223+
Delegated to self.search_strategy.notify_runner_results.
224+
225+
Parameters
226+
----------
227+
measure_candidates : List[MeasureCandidate]
228+
The measure candidates for update.
229+
results : List[RunnerResult]
230+
The profiling results from the runner.
231+
"""
232+
if self.search_strategy is None:
233+
raise ValueError(
234+
"search_strategy is not provided."
235+
"Please construct TuneContext with search_strategy"
236+
)
237+
_ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member
238+
self,
239+
measure_candidates,
240+
results,
241+
)

src/meta_schedule/search_strategy/evolutionary_search.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
314314
/*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */
315315
inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
316316
/*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */
317-
inline void NotifyRunnerResults(const TuneContext& context,
318-
const Array<MeasureCandidate>& measure_candidates,
317+
inline void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
319318
const Array<RunnerResult>& results);
320319
};
321320

@@ -430,11 +429,10 @@ class EvolutionarySearchNode : public SearchStrategyNode {
430429
return this->state_->GenerateMeasureCandidates();
431430
}
432431

433-
void NotifyRunnerResults(const TuneContext& context,
434-
const Array<MeasureCandidate>& measure_candidates,
432+
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
435433
const Array<RunnerResult>& results) final {
436434
ICHECK(this->state_ != nullptr);
437-
this->state_->NotifyRunnerResults(context, measure_candidates, results);
435+
this->state_->NotifyRunnerResults(measure_candidates, results);
438436
}
439437
};
440438

@@ -681,8 +679,7 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
681679
}
682680

683681
void EvolutionarySearchNode::State::NotifyRunnerResults(
684-
const TuneContext& context, const Array<MeasureCandidate>& measure_candidates,
685-
const Array<RunnerResult>& results) {
682+
const Array<MeasureCandidate>& measure_candidates, const Array<RunnerResult>& results) {
686683
st += results.size();
687684
ed += results.size();
688685
}

src/meta_schedule/search_strategy/replay_func.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ class ReplayFuncNode : public SearchStrategyNode {
9898
return this->state_->GenerateMeasureCandidates();
9999
}
100100

101-
void NotifyRunnerResults(const TuneContext& context,
102-
const Array<MeasureCandidate>& measure_candidates,
101+
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
103102
const Array<RunnerResult>& results) final {
104103
ICHECK(this->state_ != nullptr);
105104
this->state_->NotifyRunnerResults(results);

src/meta_schedule/search_strategy/replay_trace.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ class ReplayTraceNode : public SearchStrategyNode {
113113
return this->state_->GenerateMeasureCandidates();
114114
}
115115

116-
void NotifyRunnerResults(const TuneContext& context,
117-
const Array<MeasureCandidate>& measure_candidates,
116+
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
118117
const Array<RunnerResult>& results) final {
119118
ICHECK(this->state_ != nullptr);
120119
this->state_->NotifyRunnerResults(results);

src/meta_schedule/search_strategy/search_strategy.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,11 @@ Optional<Array<MeasureCandidate>> PySearchStrategyNode::GenerateMeasureCandidate
5252
return f_generate_measure_candidates();
5353
}
5454

55-
void PySearchStrategyNode::NotifyRunnerResults(const TuneContext& context,
56-
const Array<MeasureCandidate>& measure_candidates,
55+
void PySearchStrategyNode::NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
5756
const Array<RunnerResult>& results) {
5857
ICHECK(f_notify_runner_results != nullptr)
5958
<< "PySearchStrategy's NotifyRunnerResults method not implemented!";
60-
f_notify_runner_results(context, measure_candidates, results);
59+
f_notify_runner_results(measure_candidates, results);
6160
}
6261

6362
SearchStrategy SearchStrategy::PySearchStrategy(

src/meta_schedule/task_scheduler/task_scheduler.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int task_id) {
182182
for (RunnerFuture future : futures) {
183183
results.push_back(future->Result());
184184
}
185-
task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(),
186-
results);
185+
task->search_strategy.value()->NotifyRunnerResults(task->measure_candidates.value(), results);
187186
// Invoke the callbacks
188187
ICHECK(task->measure_candidates.defined());
189188
ICHECK(task->builder_results.defined());

tests/python/unittest/test_meta_schedule_search_strategy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_meta_schedule_replay_func(
113113
error_msg=None,
114114
)
115115
)
116-
strategy.notify_runner_results(context, candidates, runner_results)
116+
strategy.notify_runner_results(candidates, runner_results)
117117
candidates = strategy.generate_measure_candidates()
118118
strategy.post_tuning()
119119
assert num_trials_each_iter == [7, 7, 6]
@@ -178,7 +178,7 @@ def _schedule_matmul_small(sch: Schedule):
178178
error_msg=None,
179179
)
180180
)
181-
strategy.notify_runner_results(context, candidates, runner_results)
181+
strategy.notify_runner_results(candidates, runner_results)
182182
candidates = strategy.generate_measure_candidates()
183183
strategy.post_tuning()
184184
assert sum(num_trials_each_iter) == 25
@@ -242,7 +242,7 @@ def _schedule_matmul_empty(sch: Schedule):
242242
error_msg=None,
243243
),
244244
)
245-
strategy.notify_runner_results(context, candidates, runner_results)
245+
strategy.notify_runner_results(candidates, runner_results)
246246
candidates = strategy.generate_measure_candidates()
247247
strategy.post_tuning()
248248
assert num_trials_each_iter == [1, 0, 0, 0, 0]

0 commit comments

Comments
 (0)