Skip to content
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class CostModelNode : public runtime::Object {
const Array<RunnerResult>& results) = 0;

/*!
* \brief Predict the running results of given measure candidates.
* \brief Predict the normalized score (the larger the better) of given measure candidates.
* \param tune_context The tuning context.
* \param candidates The measure candidates.
* \return The predicted running results.
* \return The predicted normalized score.
*/
virtual std::vector<double> Predict(const TuneContext& tune_context,
const Array<MeasureCandidate>& candidates) = 0;
Expand Down
24 changes: 23 additions & 1 deletion include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ class DatabaseNode : public runtime::Object {
public:
/*! \brief Default destructor */
virtual ~DatabaseNode() = default;
/*!
* \brief Check if the database has the given workload.
* \param mod The IRModule to be searched for.
* \return Whether the database has the given workload.
*/
virtual bool HasWorkload(const IRModule& mod) = 0;
/*!
* \brief Look up or add workload to the database if missing.
* \param mod The IRModule to be searched for or added.
Expand Down Expand Up @@ -186,6 +192,12 @@ class DatabaseNode : public runtime::Object {
/*! \brief The database with customized methods on the python-side. */
class PyDatabaseNode : public DatabaseNode {
public:
/*!
* \brief The function type of `HasWorkload` method.
* \param mod The IRModule to be searched for.
* \return Whether the database has the given workload.
*/
using FHasWorkload = runtime::TypedPackedFunc<bool(const IRModule&)>;
/*!
* \brief The function type of `CommitWorkload` method.
* \param mod The IRModule to be searched for or added.
Expand All @@ -210,6 +222,8 @@ class PyDatabaseNode : public DatabaseNode {
*/
using FSize = runtime::TypedPackedFunc<int64_t()>;

/*! \brief The packed function to the `HasWorkload` function. */
FHasWorkload f_has_workload;
/*! \brief The packed function to the `CommitWorkload` function. */
FCommitWorkload f_commit_workload;
/*! \brief The packed function to the `CommitTuningRecord` function. */
Expand All @@ -224,12 +238,18 @@ class PyDatabaseNode : public DatabaseNode {
// so it cannot be accessible on the python side. If there is such need from the future,
// we can then add corresponding accessor methods to help access on python.
//
// `f_has_workload` is not visited
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_size` is not visited
}

bool HasWorkload(const IRModule& mod) final {
ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
return f_has_workload(mod);
}

Workload CommitWorkload(const IRModule& mod) final {
ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
return f_commit_workload(mod);
Expand Down Expand Up @@ -271,13 +291,15 @@ class Database : public runtime::ObjectRef {
bool allow_missing);
/*!
* \brief Create a database with customized methods on the python-side.
* \param f_has_workload The packed function of `HasWorkload`.
* \param f_commit_workload The packed function of `CommitWorkload`.
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload,
TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitWorkload f_commit_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FSize f_size);
Expand Down
29 changes: 28 additions & 1 deletion include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@

#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

// Forward declaration
class TuneContext;
class CostModel;
class Database;

/*! \brief The schedule (with input shapes) to be measured. */
class MeasureCandidateNode : public runtime::Object {
Expand Down Expand Up @@ -255,6 +256,32 @@ class SearchStrategy : public runtime::ObjectRef {
*/
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);

/*!
* \brief Constructor of evolutionary search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for evolutionary search.
* \param population The initial sample population.
* \param max_replay_fail_cnt The maximum number to fail trace replaying.
* \param init_measured_ratio The ratio of measures samples in initial population.
* \param genetic_algo_iters The iterations to run the genetic algorithm.
* \param max_evolve_fail_cnt The maximum number to try evolving the given trace.
* \param p_mutate The probability of mutation.
* \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
* \param database The database to use.
* \param cost_model The cost model to use.
*/
TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
int num_trials_total, //
int population, //
int max_replay_fail_cnt, //
double init_measured_ratio, //
int genetic_algo_iters, //
int max_evolve_fail_cnt, //
double p_mutate, //
double eps_greedy, //
Database database, //
CostModel cost_model);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
};

Expand Down
10 changes: 5 additions & 5 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class TuneContextNode : public runtime::Object {
Array<ScheduleRule> sch_rules;
/*! \brief The postprocessors. */
Array<Postproc> postprocs;
/*! \brief The mutators. */
Array<Mutator> mutators;
/*! \brief The probability of using certain mutator. */
Optional<Map<Mutator, FloatImm>> mutator_probs;
/*! \brief The name of the tuning task. */
String task_name;
/*! \brief The random state. */
Expand All @@ -73,7 +73,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("search_strategy", &search_strategy);
v->Visit("sch_rules", &sch_rules);
v->Visit("postprocs", &postprocs);
v->Visit("mutators", &mutators);
v->Visit("mutator_probs", &mutator_probs);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
Expand Down Expand Up @@ -104,7 +104,7 @@ class TuneContext : public runtime::ObjectRef {
* \param search_strategy The search strategy.
* \param sch_rules The schedule rules.
* \param postprocs The postprocessors.
* \param mutators The mutators.
* \param mutator_probs The probability of using certain mutator.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
Expand All @@ -115,7 +115,7 @@ class TuneContext : public runtime::ObjectRef {
Optional<SearchStrategy> search_strategy, //
Optional<Array<ScheduleRule>> sch_rules, //
Optional<Array<Postproc>> postprocs, //
Optional<Array<Mutator>> mutators, //
Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
Expand Down
1 change: 1 addition & 0 deletions include/tvm/support/random_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class LinearCongruentialEngine {
* \param rand_state The random state given in result_type.
*/
void Seed(TRandState rand_state = 1) {
ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!";
rand_state %= modulus; // Make sure the seed is within the range of modulus.
if (rand_state == 0)
rand_state = 1; // Avoid getting all 0 given the current parameter set.
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate])
Return
------
result : np.ndarray
The predicted running results.
The predicted normalized score.
"""
n = len(candidates)
results = np.zeros(shape=(n,), dtype="float64")
Expand Down Expand Up @@ -117,7 +117,6 @@ def f_save(path: str) -> None:

@check_override(self.__class__, CostModel)
def f_update(
self,
tune_context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def __init__(
# model-related
if config.nthread is None:
# use physical core number
config._replace(nthread=cpu_count(logical=False))
config = config._replace(nthread=cpu_count(logical=False))
self.config = config
# serialization-related
if path is not None:
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ def from_json(json_obj: Any, workload: Workload) -> "TuningRecord":
class Database(Object):
"""The abstract database interface."""

def has_workload(self, mod: IRModule) -> bool:
"""Check if the database has the given workload.

Parameters
----------
mod : IRModule
The IRModule to be searched for.

Returns
-------
result : bool
Whether the database has the given workload.
"""
return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member

def commit_workload(self, mod: IRModule) -> Workload:
"""Commit a workload to the database if missing.

Expand Down Expand Up @@ -207,6 +222,10 @@ class PyDatabase(Database):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Database)
def f_has_workload(mod: IRModule) -> bool:
return self.has_workload(mod)

@check_override(self.__class__, Database)
def f_commit_workload(mod: IRModule) -> Workload:
return self.commit_workload(mod)
Expand All @@ -225,6 +244,7 @@ def f_size() -> int:

self.__init_handle_by_constructor__(
_ffi_api.DatabasePyDatabase, # type: ignore # pylint: disable=no-member
f_has_workload,
f_commit_workload,
f_commit_tuning_record,
f_get_top_k,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
from .replay_trace import ReplayTrace, ReplayTraceConfig
from .replay_func import ReplayFunc, ReplayFuncConfig
from .evolutionary_search import EvolutionarySearch
106 changes: 106 additions & 0 deletions python/tvm/meta_schedule/search_strategy/evolutionary_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Evolutionary Search Strategy"""

from typing import TYPE_CHECKING, Dict

from tvm._ffi import register_object

from .search_strategy import SearchStrategy
from ..mutator import Mutator
from ..database import Database

from .. import _ffi_api

if TYPE_CHECKING:
from ..cost_model import CostModel


@register_object("meta_schedule.EvolutionarySearch")
class EvolutionarySearch(SearchStrategy):
"""
Replay Trace Search Strategy is a search strategy that always replays the trace by removing its
decisions so that the decisions would be randomly re-generated.

Parameters
----------
num_trials_per_iter : int
Number of trials per iteration.
num_trials_total : int
Total number of trials.
population : int
The initial population of traces from measured samples and randomly generated samples.
max_replay_fail_cnt : int
The maximum number to fail trace replaying.
init_measured_ratio : int
The ratio of measured samples in the initial population.
genetic_algo_iters : int
The number of iterations for genetic algorithm.
max_evolve_fail_cnt : int
The maximum number to retry mutation.
p_mutate : float
The probability of mutation.
eps_greedy : float
The ratio of greedy selected samples in the final picks.
database : Database
The database used in the search.
cost_model : CostModel
The cost model used in the search.
"""

num_trials_per_iter: int
num_trials_total: int
population: int
init_measured_ratio: int
genetic_algo_iters: int
max_replay_fail_cnt: int
max_evolve_fail_cnt: int
p_mutate: float
eps_greedy: float
database: Database
cost_model: "CostModel"

def __init__(
self,
*,
num_trials_per_iter: int,
num_trials_total: int,
database: Database,
cost_model: "CostModel",
population: int = 2048,
max_replay_fail_cnt: int = 64,
init_measured_ratio: float = 0.2,
genetic_algo_iters: int = 10,
max_evolve_fail_cnt: int = 10,
p_mutate: float = 0.85,
eps_greedy: float = 0.25,
):
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
population,
max_replay_fail_cnt,
init_measured_ratio,
genetic_algo_iters,
max_evolve_fail_cnt,
p_mutate,
eps_greedy,
database,
cost_model,
)
Loading