Skip to content

Commit 5779bf9

Browse files
Siyuan FengjunrushaozxybazhspectrometerHBHMasterJH5574
committed
[MetaSchedule] random compute location
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]>
1 parent 7485413 commit 5779bf9

File tree

19 files changed

+709
-0
lines changed

19 files changed

+709
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ class ScheduleNode : public runtime::Object {
210210
*/
211211
virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
212212
Optional<Array<Integer>> decision = NullOpt) = 0;
213+
/*!
214+
* \brief Sample a compute-at location of the given block
215+
* \param block_rv The block whose compute-at location is to be sampled
216+
* \param decision The sampling decision
217+
* \return The sampled loop where the input block is to be computed at
218+
*/
219+
virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
220+
Optional<Integer> decision = NullOpt) = 0;
213221

214222
/******** Schedule: Get blocks & loops ********/
215223
/*!

include/tvm/tir/stmt.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,13 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
13611361
*/
13621362
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
13631363

1364+
/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
1365+
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1366+
1367+
/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
1368+
constexpr const char* meta_schedule_random_compute_producer =
1369+
"meta_schedule.random_compute_producer";
1370+
13641371
/*!
13651372
* \brief Check if attr_key is a pragma key extension
13661373
* \param attr_key The attr key to be compared

python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
blocks in a schedule. See also PostOrderApply.
1818
"""
1919
from .schedule_rule import PyScheduleRule, ScheduleRule
20+
from .random_compute_location import RandomComputeLocation
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Rule that randomly select a compute-at location for a free block"""
18+
from tvm._ffi import register_object
19+
20+
from .. import _ffi_api
21+
from .schedule_rule import ScheduleRule
22+
23+
24+
@register_object("meta_schedule.RandomComputeLocation")
25+
class RandomComputeLocation(ScheduleRule):
26+
"""A rule that randomly select a compute-at location for a free block"""
27+
28+
def __init__(self) -> None:
29+
self.__init_handle_by_constructor__(
30+
_ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member
31+
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
18+
from typing import List
19+
20+
from tvm.tir import Schedule
21+
from tvm.tir.schedule import Trace
22+
23+
24+
def check_trace(spaces: List[Schedule], expected: List[List[str]]):
25+
expected_traces = {"\n".join(t) for t in expected}
26+
actual_traces = set()
27+
for space in spaces:
28+
trace = Trace(space.trace.insts, {})
29+
trace = trace.simplified(remove_postproc=True)
30+
str_trace = "\n".join(str(trace).strip().splitlines())
31+
actual_traces.add(str_trace)
32+
assert str_trace in expected_traces, "\n" + str_trace
33+
assert len(expected_traces) == len(actual_traces)

python/tvm/tir/schedule/schedule.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,32 @@ def sample_perfect_tile(
369369
)
370370
)
371371

372+
@type_checked
373+
def sample_compute_location(
374+
self,
375+
block: BlockRV,
376+
decision: Optional[int] = None,
377+
) -> LoopRV:
378+
"""Sample a compute-at location of the given block
379+
380+
Parameters
381+
----------
382+
block : BlockRV
383+
The block whose compute-at location is to be sampled
384+
decision : Optional[int]
385+
The sampling decision
386+
387+
Returns
388+
-------
389+
result : LoopRV
390+
The sampled loop where the input block is to be computed at
391+
"""
392+
return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member
393+
self,
394+
block,
395+
decision,
396+
)
397+
372398
########## Schedule: Get blocks & loops ##########
373399
@type_checked
374400
def get_block(
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace meta_schedule {
23+
24+
class RandomComputeLocationNode : public ScheduleRuleNode {
25+
public:
26+
// Inherited from ScheduleRuleNode
27+
void InitializeWithTuneContext(const TuneContext& context) final {}
28+
29+
// Inherited from ScheduleRuleNode
30+
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
31+
if (!CheckConditions(sch, block_rv)) {
32+
return {sch};
33+
}
34+
35+
// Step 1. If the producer of the input block needs a random compute-at location (specified by
36+
// the annotation), we collect the producer first, and transform the producer block later.
37+
// - The reason we collect the producer before transforming the input block is that, if the
38+
// decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
39+
// access the input block. Hence we collect its producer ahead of time.
40+
// - Note that only single producer is allowed in this case.
41+
Array<tir::BlockRV> producers{nullptr};
42+
if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
43+
true)) {
44+
producers = sch->GetProducers(block_rv);
45+
sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
46+
ICHECK_EQ(producers.size(), 1);
47+
}
48+
49+
// Step 2. Transform the input block.
50+
tir::Schedule res = RandomlyComputeAt(sch, block_rv);
51+
52+
// Step 3. Transform the producer block if compute-location sampling is needed.
53+
if (producers.defined()) {
54+
res = RandomlyComputeAt(res, producers[0]);
55+
}
56+
57+
return {res};
58+
}
59+
60+
private:
61+
bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
62+
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
63+
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
64+
65+
// Cond 1. The block is not the root block.
66+
if (block_sref->parent == nullptr) {
67+
return false;
68+
}
69+
// Cond 2. The block should be the direct child block of the root block.
70+
if (GetScopeRoot(sch->state(), block_sref, //
71+
/*require_stage_pipeline=*/false, //
72+
/*require_subtree_compact_dataflow=*/false)
73+
->parent != nullptr) {
74+
return false;
75+
}
76+
// Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
77+
// block.
78+
Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
79+
if (loop_srefs.empty()) {
80+
return false;
81+
}
82+
if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
83+
return false;
84+
}
85+
// Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
86+
if (tir::HasBeenMultiLevelTiled(block_sref)) {
87+
return false;
88+
}
89+
// Cond 6. The block has at lease one consumer.
90+
if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
91+
return false;
92+
}
93+
return true;
94+
}
95+
96+
/*!
97+
* \brief Keep sampling a compute-at location for the input block until success.
98+
* \param sch The TIR schedule
99+
* \param block_rv The block whose compute-at location is to be sampled
100+
* \return The TIR schedule after transformation
101+
*/
102+
tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
103+
for (;;) {
104+
tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
105+
try {
106+
sch->ComputeAt(block_rv, compute_at_loc, true);
107+
} catch (const dmlc::Error& e) {
108+
// ComputeAt fails, cleanup the following before re-try:
109+
// 1) trace: instruction & decisions
110+
// 2) sym_tab
111+
sch->trace().value()->Pop();
112+
sch->RemoveRV(compute_at_loc);
113+
continue;
114+
}
115+
break;
116+
}
117+
return sch;
118+
}
119+
120+
public:
121+
void VisitAttrs(tvm::AttrVisitor* v) {}
122+
123+
static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation";
124+
TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode);
125+
};
126+
127+
ScheduleRule ScheduleRule::RandomComputeLocation() {
128+
return ScheduleRule(make_object<RandomComputeLocationNode>());
129+
}
130+
131+
TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode);
132+
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation")
133+
.set_body_typed(ScheduleRule::RandomComputeLocation);
134+
} // namespace meta_schedule
135+
} // namespace tvm

src/tir/schedule/analysis.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,39 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
266266
*/
267267
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);
268268

269+
/*!
270+
* \brief Get the IterVarType of the specific loop, according to the blocks it's bound to
271+
* \param loop_sref The loop to be checked
272+
* \return The IterVarType of the specific loop
273+
*/
274+
IterVarType GetLoopIterType(const StmtSRef& loop_sref);
275+
276+
/*!
277+
* \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree
278+
* \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried
279+
* \return The lowest common ancestor of the input block srefs or loop srefs
280+
* \note The input array is required to have at least one sref
281+
*/
282+
StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs);
283+
284+
/*!
285+
* \brief Checks if the given block has been applied by multi-level tiling. We check this by
286+
* examine the block's annotation.
287+
* \param block_sref The block to be checked
288+
* \return A boolean indicating whether the block has been multi-level tiled.
289+
*/
290+
bool HasBeenMultiLevelTiled(const StmtSRef& block_sref);
291+
292+
/*!
293+
* \brief Collect all the feasible compute-at locations of the input block
294+
* \param self The schedule state
295+
* \param block_sref The block whose compute-at locations are to be collected
296+
* \return All the feasible compute-at locations of the input block, given as an array of loop srefs
297+
* and an array of their indices among the outer loops of the input block
298+
*/
299+
std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self,
300+
const StmtSRef& block_sref);
301+
269302
/******** Producer-consumer relation ********/
270303

271304
/*!

0 commit comments

Comments
 (0)