Skip to content

Commit 09bd3cc

Browse files
Siyuan FengjunrushaozxybazhspectrometerHBHMasterJH5574
committed
[MetaSchedule] Mutator: Mutate compute location
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]>
1 parent e9ee73f commit 09bd3cc

File tree

4 files changed

+248
-0
lines changed

4 files changed

+248
-0
lines changed

python/tvm/meta_schedule/mutator/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
design space.
2121
"""
2222
from .mutator import Mutator, PyMutator
23+
from .mutate_compute_location import MutateComputeLocation
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A mutator that mutates the compute-at location decision of SampleComputeLocation"""
18+
from tvm._ffi.registry import register_object
19+
20+
from .. import _ffi_api
21+
from .mutator import Mutator
22+
23+
24+
@register_object("meta_schedule.MutateComputeLocation")
class MutateComputeLocation(Mutator):
    """Mutator that changes the compute-at location decided by SampleComputeLocation.

    The Python object is only a handle; construction is delegated to the
    `MutatorMutateComputeLocation` function looked up on `_ffi_api`.
    """

    def __init__(self) -> None:
        ctor = _ffi_api.MutatorMutateComputeLocation  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(ctor)
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace meta_schedule {
23+
24+
using tir::Instruction;
25+
using tir::InstructionKind;
26+
using tir::Trace;
27+
28+
/*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */
29+
class MutateComputeLocationNode : public MutatorNode {
30+
public:
31+
/*! \brief JSON representation of the workload */
32+
std::string json_mod_;
33+
34+
void VisitAttrs(tvm::AttrVisitor* v) {}
35+
static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation";
36+
TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode);
37+
38+
public:
39+
// Inherit from `MutatorNode`
40+
void InitializeWithTuneContext(const TuneContext& context) final {
41+
this->json_mod_ = SaveJSON(context->mod.value());
42+
}
43+
// Inherit from `MutatorNode`
44+
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
45+
46+
private:
47+
struct Candidate {
48+
/*! \brief The SampleComputeLocation instruction */
49+
Instruction inst;
50+
/*! \brief The candidate compute-at locations */
51+
std::vector<int> locs;
52+
53+
explicit Candidate(Instruction inst, std::vector<int> locs)
54+
: inst(std::move(inst)), locs(std::move(locs)) {}
55+
};
56+
57+
std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state);
58+
};
59+
60+
/*!
61+
* \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated
62+
* to at least one other value
63+
* \param trace The trace from which to find the instructions
* \param rand_state The random state; forked to seed the schedule the trace is replayed on
64+
* \return All the candidate instructions together with the candidate compute-at locations
65+
*/
66+
std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::FindCandidates(
67+
const Trace& trace, TRandState* rand_state) {
68+
tir::Schedule sch = tir::Schedule::Traced( //
69+
/*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), //
70+
/*rand_state=*/ForkSeed(rand_state), //
71+
/*debug_mode=*/0, //
72+
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
73+
74+
static InstructionKind inst_sample_compute_location =
75+
InstructionKind::Get("SampleComputeLocation");
76+
std::vector<MutateComputeLocationNode::Candidate> candidates;
77+
78+
auto f_decision_provider = [&](const tir::Instruction& inst, //
79+
const Array<ObjectRef>& inputs, //
80+
const Array<ObjectRef>& attrs, //
81+
const ObjectRef& decision) -> ObjectRef {
82+
if (inst->kind.same_as(inst_sample_compute_location)) {
83+
// Step 1. Extract the instruction input and the old decision.
84+
ICHECK_EQ(inputs.size(), 1);
85+
tir::StmtSRef block_sref = sch->GetSRef(Downcast<tir::BlockRV>(inputs[0]));
86+
int old_decision = Downcast<Integer>(decision)->value;
87+
88+
// Step 2. Collect all the compute_at locations.
89+
Array<tir::StmtSRef> location_srefs;
90+
std::vector<int> location_indices;
91+
std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref);
92+
// Step 3. Remove the old decision.
93+
auto it = std::find(location_indices.begin(), location_indices.end(), old_decision);
94+
if (it != location_indices.end()) {
95+
location_srefs.erase(location_srefs.begin() + (it - location_indices.begin()));
96+
location_indices.erase(it);
97+
}
98+
ICHECK_EQ(location_srefs.size(), location_indices.size());
99+
// Step 4. Add a new candidate if there are at least one remaining compute-at position.
100+
if (!location_srefs.empty()) {
101+
candidates.emplace_back(inst, std::move(location_indices));
102+
}
103+
}
104+
return decision;
105+
};
106+
trace->ApplyToSchedule(sch, //
107+
/*remove_postproc=*/true, //
108+
/*decision_provider=*/f_decision_provider);
109+
return candidates;
110+
}
111+
112+
/*!
 * \brief Pick one mutable SampleComputeLocation instruction and give it a new decision
 * \param trace The trace to be mutated
 * \param rand_state The random state used to pick the instruction and the new location
 * \return The trace with the new decision, or NullOpt if no instruction can be mutated
 */
Optional<Trace> MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) {
  std::vector<Candidate> mutable_insts = FindCandidates(trace, rand_state);
  if (mutable_insts.empty()) {
    return NullOpt;
  }
  // First draw selects the instruction, second draw selects its new location.
  const auto inst_idx = tir::SampleInt(rand_state, 0, mutable_insts.size());
  const Candidate& chosen = mutable_insts[inst_idx];
  const auto loc_idx = tir::SampleInt(rand_state, 0, chosen.locs.size());
  int new_loc = chosen.locs[loc_idx];
  return trace->WithDecision(chosen.inst, Integer(new_loc), /*remove_postproc=*/true);
}
121+
122+
/*! \brief Create a mutator backed by a new `MutateComputeLocationNode` */
Mutator Mutator::MutateComputeLocation() {
  auto node = make_object<MutateComputeLocationNode>();
  return Mutator(std::move(node));
}
125+
126+
// Register the node class with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode);
127+
// Expose the factory under the global name that the Python wrapper looks up via `_ffi_api`.
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation")
128+
.set_body_typed(Mutator::MutateComputeLocation);
129+
130+
} // namespace meta_schedule
131+
} // namespace tvm
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
18+
from tvm.meta_schedule import TuneContext
19+
from tvm.meta_schedule.mutator import MutateComputeLocation, Mutator
20+
from tvm.script import tir as T
21+
from tvm.target import Target
22+
from tvm.tir import Schedule
23+
24+
# pylint: disable=invalid-name, no-member
25+
26+
27+
# Test workload with two blocks: "move" copies A into the intermediate buffer A_cached,
# then "add" reads A_cached and writes A_cached + 1 into B.
@T.prim_func
28+
def add(a: T.handle, b: T.handle) -> None:
29+
# function attr dict
30+
T.func_attr({"global_symbol": "main"})
31+
A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32")
32+
B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32")
33+
A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32")
34+
# body
35+
# Stage 1: element-wise copy of A into A_cached over a plain 3-D loop nest.
for i, j, k in T.grid(2048, 2048, 2048):
36+
with T.block("move"):
37+
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
38+
T.reads([A[vi, vj, vk]])
39+
T.writes([A_cached[vi, vj, vk]])
40+
A_cached[vi, vj, vk] = A[vi, vj, vk]
41+
# Stage 2: eight loops whose indices are recombined into the three 2048-extent
# spatial axes (128*16, 64*32 and 64*32 respectively).
for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
42+
with T.block("add"):
43+
vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2)
44+
vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2)
45+
vk = T.axis.spatial(2048, k0 * 32 + k1)
46+
T.reads([A_cached[vi, vj, vk]])
47+
T.writes([B[vi, vj, vk]])
48+
B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1)
49+
50+
51+
# pylint: enable=invalid-name, no-member
52+
53+
54+
def _sch(decision: int) -> Schedule:
    """Schedule `add` so that block "move" is computed at a sampled location.

    `decision` is passed to `sample_compute_location` as its decision.
    """
    schedule = Schedule(add, debug_mask="all")
    move_block = schedule.get_block(name="move", func_name="main")
    location = schedule.sample_compute_location(block=move_block, decision=decision)
    schedule.compute_at(block=move_block, loop=location, preserve_unit_loops=True)
    return schedule
62+
63+
64+
def _make_mutator(target: Target) -> Mutator:
    """Create a `MutateComputeLocation` mutator initialized with `add` and `target`."""
    context = TuneContext(mod=add, target=target)
    result = MutateComputeLocation()
    result.initialize_with_tune_context(context)
    return result
68+
69+
70+
def test_mutate_compute_location_add():
    """The mutator must always move away from the original decision (4) and, over
    100 attempts, produce 9 distinct alternative decisions."""
    mutator = _make_mutator(target=Target("llvm"))
    sch = _sch(decision=4)
    results = set()
    for _ in range(100):
        trace = mutator.apply(sch.trace)
        # The mutator yields no trace when nothing can be mutated; fail with a clear
        # message instead of an attribute error on the next line.
        assert trace is not None, "mutator did not find any decision to mutate"
        # The schedule ends with sample_compute_location followed by compute_at,
        # so the sampling instruction is the second to last one.
        decision = trace.decisions[trace.insts[-2]]
        assert decision != 4
        results.add(decision)
    assert len(results) == 9
82+
83+
84+
# Run the test directly when this file is executed as a script.
if __name__ == "__main__":
85+
test_mutate_compute_location_add()

0 commit comments

Comments
 (0)