Skip to content

Commit e6a321c

Browse files
MasterJH5574, junrushao, zxybazh, spectrometerHBH, Siyuan Feng
committed
[MetaSchedule] Post Processor: Rewrite Unbound Block
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]>
1 parent 7bfb11b commit e6a321c

File tree

5 files changed

+402
-0
lines changed

5 files changed

+402
-0
lines changed

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from .postproc import Postproc, PyPostproc
1919
from .disallow_dynamic_loop import DisallowDynamicLoop
2020
from .rewrite_reduction_block import RewriteReductionBlock
21+
from .rewrite_unbound_block import RewriteUnboundBlock
2122
from .verify_gpu_code import VerifyGPUCode
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that adds thread binding to unbound blocks"""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
23+
24+
@register_object("meta_schedule.RewriteUnboundBlock")
class RewriteUnboundBlock(Postproc):
    """Postprocessor that attaches thread bindings to blocks that lack them.

    The object takes no arguments; it is created through the FFI constructor
    ``PostprocRewriteUnboundBlock``.
    """

    def __init__(self) -> None:
        ctor = _ffi_api.PostprocRewriteUnboundBlock  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(ctor)
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
/*!
 * \brief The kind of thread binding that still has to be added to an unbound block
 */
enum class BindType : int32_t {
  /*! \brief The block requires no additional thread binding */
  kNoBind = 0,
  /*! \brief Only a blockIdx binding has to be added */
  kBindBlock = 1,
  /*! \brief Both a blockIdx and a threadIdx binding have to be added */
  kBindBlockThread = 2,
};
33+
34+
/*!
35+
* \brief Check the combination of bindings to be added to the block
36+
* \param block_sref The block to be checked
37+
* \param fuse_first_num The number of loops to be fused
38+
* \return The type of binding to be added to the block
39+
*/
40+
BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
41+
Array<StmtSRef> loops = tir::GetLoops(block_sref);
42+
int n = loops.size();
43+
if (n == 0) {
44+
return BindType::kNoBind;
45+
}
46+
int i_block_idx = -1;
47+
int i_thread_idx = -1;
48+
int i_multi_child = -1;
49+
int i_spatial_loop = -1;
50+
for (int i = 0; i < n; ++i) {
51+
const StmtSRef& loop_sref = loops[i];
52+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
53+
runtime::ThreadScope thread_scope = GetThreadScope(loop);
54+
if (IsBlockIdx(thread_scope)) {
55+
if (i_block_idx == -1) {
56+
i_block_idx = i;
57+
}
58+
}
59+
if (IsThreadIdx(thread_scope)) {
60+
if (i_thread_idx == -1) {
61+
i_thread_idx = i;
62+
}
63+
}
64+
if (!IsSingleStmt(loop->body)) {
65+
if (i_multi_child == -1) {
66+
i_multi_child = i + 1;
67+
}
68+
}
69+
if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
70+
if (i_spatial_loop == i - 1) {
71+
++i_spatial_loop;
72+
}
73+
}
74+
}
75+
if (i_multi_child == -1) {
76+
i_multi_child = n;
77+
}
78+
if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
79+
return BindType::kNoBind;
80+
} else if (i_block_idx != -1 && i_thread_idx == -1) {
81+
ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
82+
throw;
83+
} else if (i_block_idx == -1 && i_thread_idx != -1) {
84+
*fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
85+
return BindType::kBindBlock;
86+
} else { // i_block_idx == -1 && i_thread_idx == -1
87+
*fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1);
88+
return BindType::kBindBlockThread;
89+
}
90+
}
91+
92+
/*! \brief Find all the blocks that are not bound */
93+
class UnboundBlockFinder : private StmtVisitor {
94+
public:
95+
static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) {
96+
UnboundBlockFinder finder(self);
97+
for (const auto& kv : self->mod->functions) {
98+
GlobalVar g_var = kv.first;
99+
BaseFunc base_func = kv.second;
100+
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
101+
finder.global_var_name_ = g_var->name_hint;
102+
finder(Downcast<BlockRealize>(prim_func->body)->block->body);
103+
}
104+
}
105+
return std::move(finder.blocks_);
106+
}
107+
108+
private:
109+
void VisitStmt_(const ForNode* loop) final {
110+
runtime::ThreadScope thread_scope = GetThreadScope(loop);
111+
if (IsBlockIdx(thread_scope)) {
112+
++n_block_idx_;
113+
} else if (IsThreadIdx(thread_scope)) {
114+
++n_thread_idx_;
115+
}
116+
if (n_block_idx_ == 0 || n_thread_idx_ == 0) {
117+
StmtVisitor::VisitStmt_(loop);
118+
}
119+
if (IsBlockIdx(thread_scope)) {
120+
--n_block_idx_;
121+
} else if (IsThreadIdx(thread_scope)) {
122+
--n_thread_idx_;
123+
}
124+
}
125+
126+
void VisitStmt_(const BlockNode* block) final {
127+
blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_);
128+
}
129+
130+
explicit UnboundBlockFinder(const ScheduleState& self)
131+
: self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {}
132+
133+
/*! \brief The schedule state */
134+
const ScheduleState& self_;
135+
/*! \brief The list of unbound blocks */
136+
std::vector<std::pair<StmtSRef, String>> blocks_;
137+
/*! \brief The number of blockIdx above the current stmt */
138+
int n_block_idx_;
139+
/*! \brief The number of threadIdx above the current stmt */
140+
int n_thread_idx_;
141+
/*! \brief The name of the global var */
142+
String global_var_name_;
143+
};
144+
145+
} // namespace tir
146+
} // namespace tvm
147+
148+
namespace tvm {
149+
namespace meta_schedule {
150+
151+
/*! \brief Add thread binding to unbound blocks */
152+
class RewriteUnboundBlockNode : public PostprocNode {
153+
public:
154+
// Inherited from PostprocNode
155+
void InitializeWithTuneContext(const TuneContext& context) final {
156+
CHECK(context->target.defined()) << "ValueError: target is not defined";
157+
Optional<Integer> warp_size = context->target.value()->GetAttr<Integer>("thread_warp_size");
158+
CHECK(warp_size.defined()) << "ValueError: missing attribute `thread_warp_size` in the target";
159+
this->warp_size_ = warp_size.value();
160+
}
161+
162+
// Inherited from PostprocNode
163+
bool Apply(const tir::Schedule& sch) final;
164+
165+
public:
166+
/*! \brief The cached warp size from Target */
167+
int warp_size_ = -1;
168+
169+
void VisitAttrs(tvm::AttrVisitor* v) {
170+
// `warp_size_` is not visited
171+
}
172+
173+
static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock";
174+
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode);
175+
};
176+
177+
/*!
 * \brief Fuse the leading loops of each unbound block and bind the result to threads
 * \param sch The schedule to be rewritten
 * \return Always true
 */
bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) {
  // The warp size must have been set by InitializeWithTuneContext beforehand
  ICHECK_NE(this->warp_size_, -1);
  std::vector<std::pair<tir::StmtSRef, String>> unbound_blocks =
      tir::UnboundBlockFinder::Find(sch->state());
  for (const std::pair<tir::StmtSRef, String>& entry : unbound_blocks) {
    const tir::StmtSRef& block_sref = entry.first;
    const String& global_var_name = entry.second;
    int fuse_first_num = 0;
    const tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num);
    if (bind_type == tir::BindType::kNoBind) {
      continue;
    }
    tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
    Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
    // Fuse the first `fuse_first_num` loops into the single loop that gets bound
    tir::LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num});
    switch (bind_type) {
      case tir::BindType::kBindBlock: {
        sch->Bind(fused, "blockIdx.x");
        break;
      }
      case tir::BindType::kBindBlockThread: {
        // Split with an inner factor of the warp size; outer -> blockIdx.x, inner -> threadIdx.x
        Array<tir::LoopRV> splits = sch->Split(fused, {NullOpt, Integer(this->warp_size_)});
        ICHECK_EQ(splits.size(), 2);
        sch->Bind(splits[0], "blockIdx.x");
        sch->Bind(splits[1], "threadIdx.x");
        break;
      }
      default:
        break;
    }
  }
  return true;
}
206+
207+
Postproc Postproc::RewriteUnboundBlock() {
208+
ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>();
209+
n->warp_size_ = -1;
210+
return Postproc(n);
211+
}
212+
213+
TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode);
214+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock")
215+
.set_body_typed(Postproc::RewriteUnboundBlock);
216+
217+
} // namespace meta_schedule
218+
} // namespace tvm

src/tir/schedule/utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,18 @@ inline Array<Stmt> AsArray(const Stmt& stmt) {
193193
return {stmt};
194194
}
195195

196+
/*!
197+
* \brief Checks of a statement is a SeqStmt that contains multiple statements
198+
* \param stmt The statement to be checked
199+
* \return A boolean indicating the result
200+
*/
201+
inline bool IsSingleStmt(const Stmt& stmt) {
202+
if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
203+
return seq_stmt->seq.size() == 1;
204+
}
205+
return true;
206+
}
207+
196208
/******** IterVar ********/
197209

198210
/*!

0 commit comments

Comments
 (0)