Skip to content

Commit a244e90

Browse files
committed
Checkpoint.
Fix cost model comment. Finish evolutionary seaarch. Remove extra code. Fix compile. Add comments. Add python part. Ad test. Update other files & comments.
1 parent 9a3b851 commit a244e90

File tree

13 files changed

+1038
-21
lines changed

13 files changed

+1038
-21
lines changed

include/tvm/meta_schedule/cost_model.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ class CostModelNode : public runtime::Object {
6161
const Array<RunnerResult>& results) = 0;
6262

6363
/*!
64-
* \brief Predict the running results of given measure candidates.
64+
* \brief Predict the normalized score (the larger the better) of given measure candidates.
6565
* \param tune_context The tuning context.
6666
* \param candidates The measure candidates.
67-
* \return The predicted running results.
67+
* \return The predicted normalized score.
6868
*/
6969
virtual std::vector<double> Predict(const TuneContext& tune_context,
7070
const Array<MeasureCandidate>& candidates) = 0;

include/tvm/meta_schedule/search_strategy.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
2121

2222
#include <tvm/meta_schedule/arg_info.h>
23+
#include <tvm/meta_schedule/database.h>
24+
#include <tvm/meta_schedule/mutator.h>
2325
#include <tvm/meta_schedule/runner.h>
2426
#include <tvm/meta_schedule/space_generator.h>
2527
#include <tvm/tir/schedule/schedule.h>
@@ -29,6 +31,7 @@ namespace meta_schedule {
2931

3032
// Forward declaration
3133
class TuneContext;
34+
class CostModel;
3235

3336
/*! \brief The schedule (with input shapes) to be measured. */
3437
class MeasureCandidateNode : public runtime::Object {
@@ -255,6 +258,17 @@ class SearchStrategy : public runtime::ObjectRef {
255258
*/
256259
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
257260

261+
TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
262+
int num_trials_total, //
263+
int population, //
264+
double init_measured_ratio, //
265+
int genetic_algo_iters, //
266+
double p_mutate, //
267+
double eps_greedy, //
268+
Map<Mutator, FloatImm> mutator_probs, //
269+
Database database, //
270+
CostModel cost_model);
271+
258272
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
259273
};
260274

python/tvm/meta_schedule/cost_model/cost_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate])
9797
Return
9898
------
9999
result : bool
100-
The predicted running results.
100+
The predicted normalized score.
101101
"""
102102
n = len(candidates)
103103
results = np.zeros(shape=(n,), dtype="float64")
@@ -127,7 +127,6 @@ def f_save(file_location: str) -> bool:
127127

128128
@check_override(self.__class__, CostModel)
129129
def f_update(
130-
self,
131130
tune_context: TuneContext,
132131
candidates: List[MeasureCandidate],
133132
results: List[RunnerResult],

python/tvm/meta_schedule/search_strategy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
2424
from .replay_trace import ReplayTrace
2525
from .replay_func import ReplayFunc
26+
from .evolutionary_search import EvolutionarySearch
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Evolutionary Search Strategy"""
18+
19+
from typing import TYPE_CHECKING, Dict
20+
21+
from tvm._ffi import register_object
22+
from ...tir import FloatImm
23+
24+
from .search_strategy import SearchStrategy
25+
from ..mutator import Mutator
26+
from ..database import Database
27+
28+
from .. import _ffi_api
29+
30+
if TYPE_CHECKING:
31+
from ..cost_model import CostModel
32+
33+
34+
@register_object("meta_schedule.EvolutionarySearch")
35+
class EvolutionarySearch(SearchStrategy):
36+
"""
37+
Replay Trace Search Strategy is a search strategy that always replays the trace by removing its
38+
decisions so that the decisions would be randomly re-generated.
39+
40+
Parameters
41+
----------
42+
num_trials_per_iter : int
43+
Number of trials per iteration.
44+
num_trials_total : int
45+
Total number of trials.
46+
population : int
47+
The initial population of traces from measured samples and randomly generated samples.
48+
init_measured_ratio : int
49+
The ratio of measured samples in the initial population.
50+
genetic_algo_iters : int
51+
The number of iterations for genetic algorithm.
52+
p_mutate : float
53+
The probability of mutation.
54+
eps_greedy : float
55+
The ratio of greedy selected samples in the final picks.
56+
mutator_probs: Dict[Mutator, FloatImm]
57+
The probability contribution of all mutators.
58+
database : Database
59+
The database used in the search.
60+
cost_model : CostModel
61+
The cost model used in the search.
62+
"""
63+
64+
num_trials_per_iter: int
65+
num_trials_total: int
66+
population: int
67+
init_measured_ratio: int
68+
genetic_algo_iters: int
69+
p_mutate: float
70+
eps_greedy: float
71+
mutator_probs: Dict[Mutator, FloatImm]
72+
database: Database
73+
cost_model: "CostModel"
74+
75+
def __init__(
76+
self,
77+
num_trials_per_iter: int,
78+
num_trials_total: int,
79+
population: int,
80+
init_measured_ratio: float,
81+
genetic_algo_iters: int,
82+
p_mutate: float,
83+
eps_greedy: float,
84+
mutator_probs: Dict[Mutator, FloatImm],
85+
database: Database,
86+
cost_model: "CostModel",
87+
):
88+
"""Constructor"""
89+
self.__init_handle_by_constructor__(
90+
_ffi_api.SearchStrategyEvolutionarySearch, # pylint: disable=no-member
91+
num_trials_per_iter,
92+
num_trials_total,
93+
population,
94+
init_measured_ratio,
95+
genetic_algo_iters,
96+
p_mutate,
97+
eps_greedy,
98+
mutator_probs,
99+
database,
100+
cost_model,
101+
)

python/tvm/meta_schedule/search_strategy/search_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from tvm._ffi import register_object
2424
from tvm.runtime import Object
25-
from tvm.tir.schedule import Schedule, Trace
25+
from tvm.tir.schedule import Schedule
2626

2727
from .. import _ffi_api
2828
from ..arg_info import ArgInfo

python/tvm/tir/schedule/schedule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def sample_compute_location(
365365
decision: Optional[int] = None,
366366
) -> LoopRV:
367367
"""Sample a compute-at location on a BlockRV so that its producer can compute at that loop
368-
368+
369369
Parameters
370370
----------
371371
block : BlockRV

0 commit comments

Comments
 (0)