Commit 1ac01b4

jinhongyii, junrushao, zxybazh, spectrometerHBH, and Siyuan Feng authored
[MetaSchedule] Schedule Rule: Cross Thread Reduction (#9994)
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
1 parent 81b66e6 commit 1ac01b4

9 files changed, +605 -4 lines changed

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 7 additions & 0 deletions
@@ -162,6 +162,13 @@ class ScheduleRule : public runtime::ObjectRef {
    */
   TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core,  //
                                          Optional<Integer> max_innermost_factor);
+  /*!
+   * \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks
+   * correspondingly when needed
+   * \param thread_extents Candidates of thread axis extent (values are required to be positive).
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
   /*!
    * \brief A rule that randomly select a compute-at location for a free block
    * \return The rule created

python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,5 +18,6 @@
 """
 from .add_rfactor import AddRFactor
 from .auto_inline import AutoInline
+from .cross_thread_reduction import CrossThreadReduction
 from .schedule_rule import PyScheduleRule, ScheduleRule
 from .random_compute_location import RandomComputeLocation
python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py

Lines changed: 41 additions & 0 deletions (new file)

@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed"""
from typing import List

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.CrossThreadReduction")
class CrossThreadReduction(ScheduleRule):
    """A schedule rule which applies cross-thread reduction to some reduction blocks
    correspondingly when needed

    Parameters
    ----------
    thread_extents: List[int]
        Candidates of thread axis extent (values are required to be positive).
    """

    def __init__(self, thread_extents: List[int]) -> None:
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleCrossThreadReduction,  # type: ignore # pylint: disable=no-member
            thread_extents,
        )
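For illustration, a minimal construction sketch (the extents below are arbitrary candidates chosen for this example, not values taken from the commit):

from tvm.meta_schedule.schedule_rule import CrossThreadReduction

# Each entry is a candidate extent for the "threadIdx.x" binding; the rule
# samples among the candidates uniformly when it fires during tuning.
rule = CrossThreadReduction(thread_extents=[32, 64, 128])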

python/tvm/meta_schedule/testing/schedule_rule.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@
 from tvm.meta_schedule.schedule_rule import (
     AddRFactor,
     AutoInline,
+    CrossThreadReduction,
     ScheduleRule,
 )
 from tvm.target import Target
@@ -53,3 +54,10 @@ def add_rfactor(target: Target) -> ScheduleRule:
     if target.kind.name == "llvm":
         return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
     raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
+def cross_thread_reduction(target: Target) -> ScheduleRule:
+    """Default schedule rule for cross-thread reduction"""
+    if target.kind.name == "cuda":
+        return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
+    raise NotImplementedError(f"{target.kind.name} is not supported")
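A short usage sketch of the helper above; it supports only CUDA targets and raises NotImplementedError for anything else:

from tvm.target import Target
from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction

rule = cross_thread_reduction(Target("cuda"))  # returns a CrossThreadReduction
# cross_thread_reduction(Target("llvm"))       # raises NotImplementedError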
src/meta_schedule/schedule_rule/cross_thread_reduction.cc

Lines changed: 285 additions & 0 deletions (new file)

@@ -0,0 +1,285 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class CrossThreadReductionNode : public ScheduleRuleNode {
 public:
  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    Target target = context->target.value();

    Optional<Integer> opt_max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block");
    Optional<Integer> opt_warp_size = target->GetAttr<Integer>("thread_warp_size");

    if (!opt_max_threads_per_block.defined()) {
      LOG(WARNING) << "Target does not have attribute \"max_threads_per_block\", therefore the "
                      "rule CrossThreadReduction will not be applied";
    }
    if (!opt_warp_size.defined()) {
      LOG(WARNING) << "Target does not have attribute \"thread_warp_size\", therefore the rule "
                      "CrossThreadReduction will not be applied";
    }
    max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value;
    warp_size = opt_warp_size.value_or(Integer(-1))->value;
  }

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    // Step 0. Check the conditions of this rule.
    if (max_threads_per_block == -1 || warp_size == -1) {
      return {sch};
    }
    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
    if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block,
                                            warp_size)) {
      return {sch};
    }

    // Step 1. Make a copy of the original schedule. The new copy is used for scheduling.
    tir::Schedule tmp_sch = sch->Copy();
    tmp_sch->Seed(sch->ForkSeed());

    // Step 2. Check the opportunity for block fusion. We say the block is "fusible" if we can
    // compute-at the block to its consumers. We want to fuse as much as possible because it
    // results in a significantly faster schedule.
    bool fusible = false;
    // `target_loop` is the loop position where the input block will be computed at.
    tir::LoopRV target_loop{nullptr};
    // `target_block` is the consumer block that we want to compute-at the input block to.
    tir::BlockRV target_block{nullptr};
    // `tgt_block_innermost_loop` is the innermost loop outside the target block.
    tir::LoopRV tgt_block_innermost_loop{nullptr};

    std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) =
        GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

    // Step 3. Try block fusion.
    int n_candidate = static_cast<int>(thread_extents.size());
    Array<FloatImm> probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate));
    tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs);
    if (fusible) {
      ICHECK(target_block.defined());
      ICHECK(target_loop.defined());

      // Step 3.1.
      // - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should
      //   first bind the innermost outer loop of `target_block` to threadIdx. Possibly we need
      //   to split the loop before binding.
      // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor.
      if (!InThreadScope(tmp_sch, target_block)) {
        const Array<tir::LoopRV>& split_res =
            tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent});
        tmp_sch->Bind(split_res[1], "threadIdx.x");
        if (tgt_block_innermost_loop.same_as(target_loop)) {
          target_loop = split_res[0];
        }
      } else {
        thread_extent = GetThreadIdxExtentFromTrace(tmp_sch->trace().value());
      }
      // Step 3.2. Do the compute-at.
      tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
      // Step 3.3. Set the storage scope of the output buffer to shared memory.
      tmp_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared");
    }

    // Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering,
    // fuse all the reduction loops.
    size_t num_spatial_loops;
    tir::LoopRV fused_reduce_loop;
    ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
    // Step 5. Split the fused reduction loop and bind the inner one to threadIdx.
    const Array<tir::LoopRV>& split_res =
        tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent});
    tmp_sch->Bind(split_res[1], "threadIdx.x");

    return {tmp_sch, sch};
  }

 private:
  /*!
   * \brief Check whether the input block is in thread scope, i.e., one of its outer loops is
   * bound to threadIdx.
   * \param sch The TensorIR schedule
   * \param block The block to be checked
   * \return A boolean indicating whether the block is in thread scope.
   */
  bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) {
    const Array<tir::LoopRV>& axes = sch->GetLoops(block);
    for (const tir::LoopRV& loop_rv : axes) {
      const tir::For& loop = sch->Get(loop_rv);
      runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get());
      if (tir::IsThreadIdx(thread_scope)) {
        return true;
      }
    }
    return false;
  }

  /*!
   * \brief Get the ExprRV which is used to define the extent of a given loop.
   * \param trace The trace of the schedule, where the extent is to be found
   * \param loop The loop whose extent is to be found
   * \param extent The result of the search
   * \return Whether the search is successful.
   */
  bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop,
                             tir::ExprRV* extent) {
    for (const tir::Instruction& inst : trace->insts) {
      if (inst->kind->name == "Split") {
        int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin();
        CHECK(inst->inputs[1 + i].defined())
            << "ValueError: Extracting an extent which needs inference is not supported so far";
        *extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]);
        return true;
      }
    }
    return false;
  }

  /*!
   * \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace.
   * \param trace The trace of the schedule, where the extent is to be found
   * \return The extent of "threadIdx.x" in the input schedule
   */
  tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
    tir::ExprRV extent{nullptr};
    for (const tir::Instruction& inst : trace->insts) {
      if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x") {
        if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) {
          return extent;
        }
      }
    }
    CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"";
    throw;
  }

  /*!
   * \brief Get the compute-at target loop and the first block under the target loop.
   * \param sch The TensorIR schedule
   * \param block_rv The block whose compute-at target loop is queried
   * \return A tuple consisting of
   * 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible);
   * 2. the compute-at target loop when fusible, or a null loop random variable;
   * 3. the first block under the target loop when fusible, or a null block random variable;
   * 4. the innermost loop outside the target block when fusible, or a null loop random variable.
   */
  std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock(
      const tir::Schedule& sch, const tir::BlockRV& block_rv) {
    // Step 1. Get all the consumers of the input block.
    Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv);

    // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is
    // not fusible.
    if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    }

    // Step 3. Calculate the lowest common ancestor of all the consumers.
    // - If the lowest common ancestor is a block:
    //   - if there is only one consumer, the target block is that consumer;
    //   - if there are multiple consumers, they must not share a common loop, and the case is not
    //     fusible;
    // - If the lowest common ancestor is a loop, the target block is also the first consumer.
    const tir::StmtSRef& lca_sref =
        tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers));
    if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    }

    // Step 4. Get the outer loops of the target block, and get the compute-at position index.
    Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]);
    int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref);

    // Step 5. A negative position index means not fusible, and vice-versa.
    if (pos < 0) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    } else {
      return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back());
    }
  }

  /*!
   * \brief Get the compute-at position index of the input block, according to
   * 1. the loops outside the input block;
   * 2. the loops outside the target block;
   * 3. the lowest common ancestor of all the consumers of the input block.
   * \param sch The TensorIR schedule
   * \param block_loops The loops outside the input block
   * \param tgt_block_loops The loops outside the target block
   * \param lca_sref The lowest common ancestor of all the consumers of the input block
   * \return The compute-at position index of the input block
   */
  int GetComputePosition(const tir::Schedule& sch, const Array<tir::LoopRV>& block_loops,
                         const Array<tir::LoopRV>& tgt_block_loops, const tir::StmtSRef& lca_sref) {
    int n_block_loop = static_cast<int>(block_loops.size());
    int n_tgt_block_loop = static_cast<int>(tgt_block_loops.size());

    for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) {
      if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) {
        return i - 1;
      } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) {
        // If the lowest common ancestor is a loop, the compute location of the input block should
        // not be deeper than the LCA loop.
        return i;
      }
    }
    return std::min(n_block_loop, n_tgt_block_loop) - 1;
  }

 public:
  /*! \brief The maximum number of threads allowed in a thread block */
  int max_threads_per_block;
  /*! \brief The number of threads per warp */
  int warp_size;
  /*! \brief Candidates of thread axis extent (values are required to be positive). */
  Array<Integer> thread_extents;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_threads_per_block", &max_threads_per_block);
    v->Visit("warp_size", &warp_size);
    v->Visit("thread_extents", &thread_extents);
  }

  static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction";
  TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::CrossThreadReduction(Array<Integer> thread_extents) {
  for (const Integer& extent : thread_extents) {
    CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive";
  }
  ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>();
  n->thread_extents = std::move(thread_extents);
  return ScheduleRule(n);
}

TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction")
    .set_body_typed(ScheduleRule::CrossThreadReduction);

}  // namespace meta_schedule
}  // namespace tvm
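To make Steps 4-5 of Apply concrete, here is a hand-written sketch of the equivalent TensorIR scheduling on a hypothetical row-sum workload. The workload, its shapes, and the split factor 128 are illustrative assumptions, not part of this commit; the rule itself samples the factor from thread_extents via SampleCategorical.

import tvm
from tvm.script import tir as T

@T.prim_func
def row_sum(a: T.handle, b: T.handle) -> None:
    # Hypothetical workload: sum each 1024-element row of a 32x1024 matrix.
    A = T.match_buffer(a, (32, 1024), "float32")
    B = T.match_buffer(b, (32,), "float32")
    for i, k in T.grid(32, 1024):
        with T.block("B"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + A[vi, vk]

sch = tvm.tir.Schedule(row_sum)
block = sch.get_block("B")
i, k = sch.get_loops(block)
# Mirror Step 5: split the innermost reduction loop and bind the inner
# piece to threadIdx.x, yielding a cross-thread (intra-block) reduction.
_, ki = sch.split(k, factors=[None, 128])
sch.bind(ki, "threadIdx.x")
sch.bind(i, "blockIdx.x")  # bind the spatial loop so the kernel is launchable
print(sch.mod.script())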

src/tir/schedule/analysis/analysis.cc

Lines changed: 3 additions & 2 deletions
@@ -1788,8 +1788,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
   const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,  //
                                             /*require_stage_pipeline=*/false,  //
                                             /*require_subtree_compact_dataflow=*/false);
-  if (!(IsReductionBlock(self, block_sref, scope_sref) &&  //
-        IsTrivialBinding(self, block_sref))) {
+  if (!IsReductionBlock(self, block_sref, scope_sref)  //
+      || !IsTrivialBinding(self, block_sref)  //
+      || HasBeenMultiLevelTiled(block_sref)) {
     return false;
   }

src/tir/schedule/trace.cc

Lines changed: 4 additions & 2 deletions
@@ -448,8 +448,10 @@ Trace TraceNode::Simplified(bool remove_postproc) const {
     }
     // Add its inputs as "used" ones
     for (const ObjectRef& obj : inst->inputs) {
-      if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
-          obj->IsInstance<VarNode>()) {
+      if (!obj.defined()) {
+        continue;
+      } else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
+                 obj->IsInstance<VarNode>()) {
         used_rvs.insert(obj.get());
         continue;
       } else if (obj->IsInstance<PrimExprNode>()) {
