Skip to content

Commit b51bee8

Browse files
committed
[TIR][Transform] Add LiftThreadBinding Pass
This PR adds a pass LiftThreadBinding to TIR. Previously, during GPU cross-thread reduction, a temporary local buffer was created for the RF buffer. As a concrete example: ```python rf_local = T.alloc_buffer(..., scope="local") // Step 1. Data parallel RF block for tx in T.thread_binding(..., thread="threadIdx.x") rf_local[tx, ...] = // Step 2. Cross-thread reduction to accumulate rf_local for ...: for tx' in T.thread_binding(..., thread="threadIdx.x"): ... += rf_local[tx', ...] ``` In this case, the buffer `rf_local` is only ever accessed at a single point, `tx` or `tx'`. However, during the pass `CompactBufferAllocation`, the two thread-binding variables are treated as two separate variables, i.e. the information that `tx` and `tx'` are always identical to each other is discarded. As a result, the accessed region of `rf_local` is estimated as `Union({tx}, {tx'})` as opposed to `{tx}`, leading to over-allocation of local registers. This pass is introduced to address this issue by lifting thread bindings to their lowest common ancestor (LCA) loops.
1 parent bcb92cf commit b51bee8

File tree

8 files changed

+365
-2
lines changed

8 files changed

+365
-2
lines changed

include/tvm/tir/transform.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,12 @@ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
482482
*/
483483
TVM_DLL Pass ConvertBlocksToOpaque();
484484

485+
/*!
486+
* \brief Lift the same thread bindings to their LCA loops
487+
* \return The pass.
488+
*/
489+
TVM_DLL Pass LiftThreadBinding();
490+
485491
/*!
486492
* \brief Compact the buffer access region by removing the buffer regions that are not accessed,
487493
* i.e. narrowing the buffer shape and adjust the access region if necessary.

include/tvm/tir/var.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ class Var : public PrimExpr {
103103
* \param span The location of this object in the source code.
104104
*/
105105
TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
106+
/*!
107+
* \brief Make a new copy of var with same type, but a different name
108+
* \param name The new name to be used.
109+
* \return the new Var copy
110+
*/
111+
TVM_DLL Var copy_with_name(const String& name) const;
106112
/*!
107113
* \brief Make a new copy of var with same type, append suffix
108114
* \param suffix The suffix to be appended.

python/tvm/tir/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,17 @@ def ConvertBlocksToOpaque():
846846
return _ffi_api.ConvertBlocksToOpaque() # type: ignore
847847

848848

849+
def LiftThreadBinding():
850+
"""Lift the same thread bindings to their LCA loops.
851+
852+
Returns
853+
-------
854+
fpass : tvm.transform.Pass
855+
The result pass
856+
"""
857+
return _ffi_api.LiftThreadBinding() # type: ignore
858+
859+
849860
def CompactBufferAllocation(is_strict: bool = True):
850861
"""Compact the buffer access region. by removing the buffer regions
851862
that are not accessed, i.e. narrowing the buffer shape and adjust

src/driver/driver_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
202202
pass_list.push_back(tir::transform::LowerInitBlock());
203203
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
204204
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
205+
pass_list.push_back(tir::transform::LiftThreadBinding());
205206
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
206207
pass_list.push_back(tir::transform::CompactBufferAllocation());
207208
pass_list.push_back(tir::transform::LowerAutoCopy());

src/meta_schedule/postproc/verify_gpu_code.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class VerifyGPUCodeNode : public PostprocNode {
161161
pass_list.push_back(tir::transform::LowerInitBlock());
162162
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
163163
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
164+
pass_list.push_back(tir::transform::LiftThreadBinding());
164165
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
165166
pass_list.push_back(tir::transform::CompactBufferAllocation());
166167
pass_list.push_back(tir::transform::Simplify());

src/tir/ir/expr.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,22 @@ Var::Var(String name_hint, Type type_annotation, Span span) {
8080
data_ = std::move(n);
8181
}
8282

83-
Var Var::copy_with_suffix(const String& suffix) const {
83+
Var Var::copy_with_name(const String& name) const {
8484
const VarNode* node = get();
8585
ObjectPtr<VarNode> new_ptr;
8686
if (auto* ptr = this->as<SizeVarNode>()) {
8787
new_ptr = make_object<SizeVarNode>(*ptr);
8888
} else {
8989
new_ptr = make_object<VarNode>(*node);
9090
}
91-
new_ptr->name_hint = new_ptr->name_hint + suffix;
91+
new_ptr->name_hint = name;
9292
return Var(new_ptr);
9393
}
9494

95+
Var Var::copy_with_suffix(const String& suffix) const {
96+
return this->copy_with_name(get()->name_hint + suffix);
97+
}
98+
9599
Var Var::copy_with_dtype(DataType dtype) const {
96100
const VarNode* node = get();
97101
ObjectPtr<VarNode> new_ptr;
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file lift_thread_binding.cc
22+
* \brief Lift the same thread bindings to their LCA loops.
23+
*/
24+
25+
#include <tvm/tir/stmt_functor.h>
26+
#include <tvm/tir/transform.h>
27+
28+
#include "../../runtime/thread_storage_scope.h"
29+
#include "./ir_utils.h"
30+
31+
namespace tvm {
32+
namespace tir {
33+
34+
std::pair<std::unordered_map<Stmt, std::vector<std::pair<IterVar, Map<String, ObjectRef>>>,
35+
ObjectPtrHash, ObjectPtrEqual>,
36+
Map<Var, Var>>
37+
FindLoopLCA(const Stmt& root) {
38+
class LCAFinder : public StmtVisitor {
39+
public:
40+
void VisitStmt_(const ForNode* op) final {
41+
stack.push_back(GetRef<Stmt>(op));
42+
StmtVisitor::VisitStmt_(op);
43+
if (op->kind == ForKind::kThreadBinding) {
44+
UpdateLCA(op);
45+
}
46+
stack.pop_back();
47+
}
48+
49+
void UpdateLCA(const ForNode* loop) {
50+
std::string thread_tag = loop->thread_binding.value()->thread_tag;
51+
{
52+
Map<String, ObjectRef>* tgt = &annotations[thread_tag];
53+
for (const auto& kv : loop->annotations) {
54+
tgt->Set(kv.first, kv.second);
55+
}
56+
}
57+
IterVar& iter_var = iters[thread_tag];
58+
if (!iter_var.defined()) {
59+
iter_var = IterVar(Range::FromMinExtent(loop->min, loop->extent), //
60+
loop->loop_var.copy_with_name(thread_tag), //
61+
loop->thread_binding.value()->iter_type, //
62+
thread_tag);
63+
lca[thread_tag] = stack;
64+
var_subst.Set(loop->loop_var, iter_var->var);
65+
return;
66+
}
67+
var_subst.Set(loop->loop_var, iter_var->var);
68+
std::vector<Stmt>& path = lca[thread_tag];
69+
uint32_t i = 0;
70+
for (; i < stack.size() && i < path.size(); ++i) {
71+
if (!stack[i].same_as(path[i])) {
72+
break;
73+
}
74+
}
75+
path.resize(i);
76+
}
77+
78+
std::unordered_map<std::string, std::vector<Stmt>> lca;
79+
std::unordered_map<std::string, IterVar> iters;
80+
std::unordered_map<std::string, Map<String, ObjectRef>> annotations;
81+
Map<Var, Var> var_subst;
82+
std::vector<Stmt> stack;
83+
};
84+
LCAFinder finder;
85+
finder(root);
86+
std::unordered_map<Stmt, std::vector<std::pair<IterVar, Map<String, ObjectRef>>>, ObjectPtrHash,
87+
ObjectPtrEqual>
88+
result;
89+
std::vector<std::string> sorted_thread_tags;
90+
for (const auto& kv : finder.lca) {
91+
sorted_thread_tags.push_back(kv.first);
92+
}
93+
std::sort(sorted_thread_tags.begin(), sorted_thread_tags.end(),
94+
[](const std::string& lhs, const std::string& rhs) {
95+
return lhs.size() > rhs.size();
96+
runtime::ThreadScope lhs_scope = runtime::ThreadScope::Create(lhs);
97+
runtime::ThreadScope rhs_scope = runtime::ThreadScope::Create(rhs);
98+
if (lhs_scope.rank != rhs_scope.rank) {
99+
return lhs_scope.rank < rhs_scope.rank;
100+
}
101+
return lhs_scope.dim_index < rhs_scope.dim_index;
102+
});
103+
for (const auto& thread_tag : sorted_thread_tags) {
104+
Stmt lca = finder.lca[thread_tag].back();
105+
const IterVar& iter = finder.iters[thread_tag];
106+
const Map<String, ObjectRef>& annotations = finder.annotations[thread_tag];
107+
result[lca].emplace_back(iter, annotations);
108+
}
109+
return {result, finder.var_subst};
110+
}
111+
112+
/*!
113+
* \brief Lift thread-binding loops that share a thread tag to their LCA loop, and rewrite the
114+
* original loop variables to the unified per-tag variables.
115+
*/
116+
class ThreadBindingLifter : public StmtExprMutator {
117+
public:
118+
Stmt VisitStmt_(const ForNode* _op) final {
119+
For op = GetRef<For>(_op);
120+
bool is_kernel_root = false;
121+
if (op->kind == ForKind::kThreadBinding) {
122+
if (iter_lca.empty()) {
123+
is_kernel_root = true;
124+
SetKernelRoot(_op);
125+
}
126+
}
127+
For new_op = Downcast<For>(StmtExprMutator::VisitStmt_(_op));
128+
Stmt body = std::move(new_op.CopyOnWrite()->body);
129+
if (auto it = iter_lca.find(op); it != iter_lca.end()) {
130+
for (const auto& [iter_var, annotation] : it->second) {
131+
body = For(iter_var->var, iter_var->dom->min, iter_var->dom->extent,
132+
ForKind::kThreadBinding, std::move(body),
133+
IterVar(Range(nullptr), Var(iter_var->thread_tag, iter_var->var->dtype),
134+
kThreadIndex, iter_var->thread_tag),
135+
annotation);
136+
}
137+
}
138+
if (is_kernel_root) {
139+
iter_lca.clear();
140+
var_subst.clear();
141+
}
142+
if (op->kind == ForKind::kThreadBinding) {
143+
return body;
144+
} else {
145+
new_op.CopyOnWrite()->body = std::move(body);
146+
return new_op;
147+
}
148+
}
149+
150+
void SetKernelRoot(const ForNode* op) {
151+
auto result = FindLoopLCA(GetRef<Stmt>(op));
152+
this->iter_lca = std::move(result.first);
153+
this->var_subst = std::move(result.second);
154+
}
155+
156+
PrimExpr VisitExpr_(const VarNode* op) final {
157+
auto it = var_subst.find(GetRef<Var>(op));
158+
if (it != var_subst.end()) {
159+
return (*it).second;
160+
} else {
161+
return GetRef<PrimExpr>(op);
162+
}
163+
}
164+
165+
std::unordered_map<Stmt, std::vector<std::pair<IterVar, Map<String, ObjectRef>>>, ObjectPtrHash,
166+
ObjectPtrEqual>
167+
iter_lca;
168+
Map<Var, Var> var_subst;
169+
};
170+
171+
/*!
 * \brief Run ThreadBindingLifter over the body of a PrimFunc.
 * \param f The function to transform.
 * \return The function with its body rewritten, or `f` unchanged when
 *         IsFromLegacyTESchedule(f) holds.
 */
PrimFunc LiftThreadBinding(PrimFunc f) {
  // TIR that comes from legacy TE schedules is not handled by this pass.
  if (IsFromLegacyTESchedule(f)) {
    return f;
  }
  PrimFuncNode* func_node = f.CopyOnWrite();
  func_node->body = ThreadBindingLifter()(std::move(func_node->body));
  return f;
}
181+
182+
namespace transform {
183+
184+
Pass LiftThreadBinding() {
185+
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
186+
return LiftThreadBinding(std::move(f));
187+
};
188+
return CreatePrimFuncPass(pass_func, 0, "tir.LiftThreadBinding", {});
189+
}
190+
191+
TVM_REGISTER_GLOBAL("tir.transform.LiftThreadBinding").set_body_typed(LiftThreadBinding);
192+
} // namespace transform
193+
194+
} // namespace tir
195+
} // namespace tvm

0 commit comments

Comments
 (0)