
Commit 79cfb79

Authored by zxybazh, junrushao, spectrometerHBH, MasterJH5574, and jinhongyii
[M3c][MetaScheduler] Add ReplayFunc Search Strategy. (#9799)
* Modify TuneContext, TaskScheduler & SearchStrategy functions.
* Retrigger CI.
* Add ReplayFunc and EvolutionarySearch strategy.
* Fix optional task name.
* Remove extra files.
* Fix things.

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
1 parent 94552fc commit 79cfb79

6 files changed: +335 additions, -10 deletions

python/tvm/meta_schedule/search_strategy/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -19,5 +19,7 @@
 Meta Schedule search strategy utilizes the design spaces given
 to generate measure candidates.
 """
-from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy
-from .replay_trace import ReplayTrace
+from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
+from .replay_trace import ReplayTrace, ReplayTraceConfig
+from .replay_func import ReplayFunc, ReplayFuncConfig
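
With these re-exports, ReplayFunc and ReplayFuncConfig (together with the new ReplayTraceConfig) can be imported directly from tvm.meta_schedule.search_strategy, alongside the existing ReplayTrace.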
python/tvm/meta_schedule/search_strategy/replay_func.py

Lines changed: 63 additions & 0 deletions (new file)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Replay Func Search Strategy"""
from typing import NamedTuple

from tvm._ffi import register_object

from .. import _ffi_api
from .search_strategy import SearchStrategy


@register_object("meta_schedule.ReplayFunc")
class ReplayFunc(SearchStrategy):
    """
    Replay Func Search Strategy is a search strategy that generates measure candidates by
    calling a design space generator and transforming the design space.

    Parameters
    ----------
    num_trials_per_iter : int
        Number of trials per iteration.
    num_trials_total : int
        Total number of trials.
    """

    num_trials_per_iter: int
    num_trials_total: int

    def __init__(
        self,
        num_trials_per_iter: int,
        num_trials_total: int,
    ):
        """Constructor"""
        self.__init_handle_by_constructor__(
            _ffi_api.SearchStrategyReplayFunc,  # type: ignore # pylint: disable=no-member
            num_trials_per_iter,
            num_trials_total,
        )


class ReplayFuncConfig(NamedTuple):
    """Configuration for ReplayFunc"""

    num_trials_per_iter: int
    num_trials_total: int

    def create_strategy(self) -> ReplayFunc:
        return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
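
For orientation, here is a minimal sketch of how the new strategy can be driven from Python, mirroring the unit test further down. The Matmul workload and _schedule_matmul schedule function are assumed to come from that test, and the import path used for ScheduleFn (tvm.meta_schedule.space_generator) is likewise an assumption, since it is not shown in this diff.

# Minimal sketch (not part of the commit): drive ReplayFunc by hand, mirroring
# the unit test below. `Matmul` and `_schedule_matmul` are assumed to be the
# workload and schedule function defined in that test.
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.search_strategy import ReplayFuncConfig
from tvm.meta_schedule.space_generator import ScheduleFn  # import path assumed

strategy = ReplayFuncConfig(num_trials_per_iter=7, num_trials_total=20).create_strategy()
tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
tune_context.space_generator.initialize_with_tune_context(tune_context)
spaces = tune_context.space_generator.generate_design_space(tune_context.mod)

strategy.initialize_with_tune_context(tune_context)
strategy.pre_tuning(spaces)
candidates = strategy.generate_measure_candidates()
while candidates is not None:
    # Report (dummy) runner results so the strategy advances to the next batch.
    results = [RunnerResult(run_secs=[0.1], error_msg=None) for _ in candidates]
    strategy.notify_runner_results(tune_context, candidates, results)
    candidates = strategy.generate_measure_candidates()
strategy.post_tuning()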
src/meta_schedule/mutator/mutator.cc

Lines changed: 57 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

Mutator Mutator::PyMutator(
    PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
    PyMutatorNode::FApply f_apply,                                             //
    PyMutatorNode::FAsString f_as_string) {
  ObjectPtr<PyMutatorNode> n = make_object<PyMutatorNode>();
  n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
  n->f_apply = std::move(f_apply);
  n->f_as_string = std::move(f_as_string);
  return Mutator(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PyMutatorNode>([](const ObjectRef& n, ReprPrinter* p) {
      const auto* self = n.as<PyMutatorNode>();
      ICHECK(self);
      PyMutatorNode::FAsString f_as_string = (*self).f_as_string;
      ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!";
      p->stream << f_as_string();
    });

TVM_REGISTER_OBJECT_TYPE(MutatorNode);
TVM_REGISTER_NODE_TYPE(PyMutatorNode);

TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext")
    .set_body_method<Mutator>(&MutatorNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply")
    .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional<tir::Trace> {
      TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom();
      return self->Apply(trace, &seed_);
    });
TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator);

}  // namespace meta_schedule
}  // namespace tvm
src/meta_schedule/postproc/postproc.cc

Lines changed: 53 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

Postproc Postproc::PyPostproc(
    PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
    PyPostprocNode::FApply f_apply,                                             //
    PyPostprocNode::FAsString f_as_string) {
  ObjectPtr<PyPostprocNode> n = make_object<PyPostprocNode>();
  n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
  n->f_apply = std::move(f_apply);
  n->f_as_string = std::move(f_as_string);
  return Postproc(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PyPostprocNode>([](const ObjectRef& n, ReprPrinter* p) {
      const auto* self = n.as<PyPostprocNode>();
      ICHECK(self);
      PyPostprocNode::FAsString f_as_string = (*self).f_as_string;
      ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!";
      p->stream << f_as_string();
    });

TVM_REGISTER_OBJECT_TYPE(PostprocNode);
TVM_REGISTER_NODE_TYPE(PyPostprocNode);

TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext")
    .set_body_method<Postproc>(&PostprocNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method<Postproc>(&PostprocNode::Apply);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc);

}  // namespace meta_schedule
}  // namespace tvm
src/meta_schedule/search_strategy/replay_func.cc

Lines changed: 151 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief A search strategy that generates measure candidates using space generator. */
class ReplayFuncNode : public SearchStrategyNode {
 public:
  /*! \brief The state of the search strategy. */
  struct State {
    /*! \brief The search strategy itself */
    ReplayFuncNode* self;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int st;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int ed;

    explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {}

    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
  };

  /*! \brief The number of trials per iteration. */
  int num_trials_per_iter;
  /*! \brief The number of total trials. */
  int num_trials_total;

  /*! \brief The module to be tuned. */
  IRModule mod_{nullptr};
  /*! \brief The metadata of the function arguments. */
  Array<ArgInfo> args_info_{nullptr};
  /*! \brief The post processors */
  Array<Postproc> postprocs_{nullptr};
  /*! \brief The space generator for measure candidates generation. */
  SpaceGenerator space_generator_{nullptr};
  /*! \brief The random state. -1 means using random number. */
  TRandState rand_state_ = -1;
  /*! \brief The state of the search strategy. */
  std::unique_ptr<State> state_ = nullptr;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("num_trials_per_iter", &num_trials_per_iter);
    v->Visit("num_trials_total", &num_trials_total);
    // `space_generator_` is not visited
    // `mod_` is not visited
    // `args_info_` is not visited
    // `num_threads_` is not visited
    // `rand_state_` is not visited
    // `state_` is not visited
  }

  static constexpr const char* _type_key = "meta_schedule.ReplayFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);

  void InitializeWithTuneContext(const TuneContext& context) final {
    this->space_generator_ = context->space_generator.value();
    this->mod_ = context->mod.value();
    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
    this->postprocs_ = context->postprocs;
    this->rand_state_ = ForkSeed(&context->rand_state);
    this->state_.reset();
  }

  void PreTuning(const Array<tir::Schedule>& design_spaces) final {
    ICHECK(this->state_ == nullptr);
    this->state_ = std::make_unique<State>(this);
  }

  void PostTuning() final {
    ICHECK(this->state_ != nullptr);
    this->state_.reset();
  }

  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
    ICHECK(this->state_ != nullptr);
    return this->state_->GenerateMeasureCandidates();
  }

  void NotifyRunnerResults(const TuneContext& context,
                           const Array<MeasureCandidate>& measure_candidates,
                           const Array<RunnerResult>& results) final {
    ICHECK(this->state_ != nullptr);
    this->state_->NotifyRunnerResults(results);
  }
};

inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
  if (st >= self->num_trials_total) {
    return NullOpt;
  }
  ed = std::min(ed, self->num_trials_total);
  Array<MeasureCandidate> result;
  for (int i = st; i < ed; i++) {
    for (;;) {
      Array<tir::Schedule> schs = self->space_generator_->GenerateDesignSpace(self->mod_);
      int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size());
      tir::Schedule sch = schs[design_space_index];
      sch->EnterPostproc();
      bool failed = false;
      for (const Postproc& proc : self->postprocs_) {
        if (!proc->Apply(sch)) {
          failed = true;
          break;
        }
      }
      if (!failed) {
        result.push_back(MeasureCandidate(sch, self->args_info_));
        break;
      }
    }
  }
  return result;
}

inline void ReplayFuncNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
  st += self->num_trials_per_iter;
  ed += self->num_trials_per_iter;
}

SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) {
  ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
  n->num_trials_per_iter = num_trials_per_iter;
  n->num_trials_total = num_trials_total;
  return SearchStrategy(n);
}

TVM_REGISTER_NODE_TYPE(ReplayFuncNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc")
    .set_body_typed(SearchStrategy::ReplayFunc);

}  // namespace meta_schedule
}  // namespace tvm
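
To summarize the control flow above in plain Python (illustrative pseudocode only, not a TVM API): each trial in the current [st, ed) batch re-runs the space generator, samples one design space at random, and retries until every postprocessor succeeds.

import random

def generate_batch(space_generator, mod, postprocs, args_info, st, ed, num_trials_total):
    """Illustrative Python rendering of ReplayFuncNode::State::GenerateMeasureCandidates."""
    if st >= num_trials_total:
        return None  # the search is exhausted
    ed = min(ed, num_trials_total)
    batch = []
    for _ in range(st, ed):
        while True:
            # Re-generate the design spaces and pick one schedule at random.
            spaces = space_generator.generate_design_space(mod)
            sch = spaces[random.randrange(len(spaces))]
            sch.enter_postproc()
            # Keep the schedule only if every postprocessor succeeds; otherwise retry.
            if all(proc.apply(sch) for proc in postprocs):
                batch.append((sch, args_info))  # stands in for MeasureCandidate(sch, args_info)
                break
    return batch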

tests/python/unittest/test_meta_schedule_search_strategy.py

Lines changed: 7 additions & 8 deletions

@@ -18,12 +18,11 @@
 # pylint: disable=missing-function-docstring
 import sys
 import pytest
-from typing import List
-
 import tvm
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.runner import RunnerResult
 from tvm.meta_schedule.search_strategy import (
+    ReplayFunc,
     ReplayTrace,
     SearchStrategy,
 )
@@ -75,17 +74,17 @@ def _schedule_matmul(sch: Schedule):
     sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)


-@pytest.mark.parametrize("TestClass", [ReplayTrace])
+@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace])
 def test_meta_schedule_replay_func(TestClass: SearchStrategy):  # pylint: disable = invalid-name
     num_trials_per_iter = 7
     num_trials_total = 20

     strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
-    context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
-    context.space_generator.initialize_with_tune_context(context)
-    spaces = context.space_generator.generate_design_space(context.mod)
+    tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+    tune_context.space_generator.initialize_with_tune_context(tune_context)
+    spaces = tune_context.space_generator.generate_design_space(tune_context.mod)

-    strategy.initialize_with_tune_context(context)
+    strategy.initialize_with_tune_context(tune_context)
     strategy.pre_tuning(spaces)
     (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
     num_trials_each_iter: List[int] = []
@@ -100,7 +99,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy):  # pylint: disable = invalid-name
                 remove_decisions=(isinstance(strategy, ReplayTrace)),
             )
             runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
-        strategy.notify_runner_results(context, candidates, runner_results)
+        strategy.notify_runner_results(tune_context, candidates, runner_results)
         candidates = strategy.generate_measure_candidates()
     strategy.post_tuning()
     assert num_trials_each_iter == [7, 7, 6]
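
The expected counts follow from the strategy's [st, ed) bookkeeping: with num_trials_per_iter = 7 and num_trials_total = 20, the batches cover [0, 7), [7, 14), and [14, 20), i.e. 7, 7, and 6 candidates, after which generate_measure_candidates() returns None, which is exactly what the final assertion checks.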
